diff --git a/Project.toml b/Project.toml index 64cf6dd..91fbc88 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.9.2" +version = "0.9.3" authors = ["ITensor developers and contributors"] [workspace] diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 71e9507..45e4c78 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -6,21 +6,31 @@ export eigen, eigvals!!, factorize, factorize!!, + gram_eigh_full, + gram_eigh_full!!, + gram_eigh_full_with_pinv, + gram_eigh_full_with_pinv!!, + invsqrt_diag_safe, + invsqrth_safe, lq, lq!!, orth, orth!!, polar, polar!!, + pow_diag_safe, + powh_safe, qr, qr!!, + sqrt_diag_safe, + sqrth_safe, svd, svd!!, svdvals, svdvals!! import MatrixAlgebraKit as MAK -using LinearAlgebra: LinearAlgebra, norm +using LinearAlgebra: LinearAlgebra, Diagonal, isdiag, norm for (f, f_full, f_compact) in ( (:qr, :qr_full, :qr_compact), @@ -74,6 +84,209 @@ for (eigvals, eigh_vals, eig_vals) in end end +function _clamp_kwargs_doc(arg::AbstractString) + return join( + ( + " - `atol::Real`: absolute clamping threshold. Default `0`.", + " - `rtol::Real`: relative clamping threshold. Default `eps(real(eltype($arg)))^(3//4)` when `atol = 0`, else `0`.", + ), "\n" + ) +end + +""" + pow_diag_safe(D::AbstractMatrix, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p + +Raise a diagonal-structured matrix `D` to the power `p`. Diagonal entries +`d` of `MAK.diagview(D)` with `abs(d) < tol` are clamped to zero before +exponentiation, where `tol = max(atol, rtol * maximum(abs, diagview(D)))`. +Negative `d` above `tol` cause `d^p` to error for fractional `p` (e.g. +`p = 1//2`) and pass through for integer `p`, so the operation itself +enforces the PSD precondition per-power. Errors if `isdiag(D)` is `false`. + +The implementation extracts entries via `MAK.diagview` and rebuilds via +`MAK.diagonal`, so types extending those (e.g. 
graded or block diagonal) +automatically extend [`sqrt_diag_safe`](@ref), [`invsqrt_diag_safe`](@ref), +and the [`powh_safe`](@ref) family. + +## Keyword arguments + +$(_clamp_kwargs_doc("D")) +""" +function pow_diag_safe( + D::AbstractMatrix, p; + atol = zero(real(eltype(D))), + rtol = iszero(atol) ? eps(real(eltype(D)))^(3 // 4) : + zero(real(eltype(D))) + ) + isdiag(D) || throw( + ArgumentError("pow_diag_safe requires a diagonal-structured matrix") + ) + σ = MAK.diagview(D) + tol = max(atol, rtol * maximum(abs, σ; init = zero(real(eltype(D))))) + return MAK.diagonal(map(d -> abs(d) < tol ? zero(real(d)) : real(d)^p, σ)) +end + +""" + sqrt_diag_safe(D::AbstractMatrix; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(1//2) + +Square root of a diagonal-structured matrix `D`, equivalent to +`pow_diag_safe(D, 1//2; atol, rtol)`. + +## Keyword arguments + +$(_clamp_kwargs_doc("D")) +""" +sqrt_diag_safe(D::AbstractMatrix; kwargs...) = pow_diag_safe(D, 1 // 2; kwargs...) + +""" + invsqrt_diag_safe(D::AbstractMatrix; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(-1//2) + +Inverse square root of a diagonal-structured matrix `D`, treating diagonal +entries below tolerance as zero (Moore-Penrose convention). Equivalent to +`pow_diag_safe(D, -1//2; atol, rtol)`. + +## Keyword arguments + +$(_clamp_kwargs_doc("D")) +""" +invsqrt_diag_safe(D::AbstractMatrix; kwargs...) = pow_diag_safe(D, -1 // 2; kwargs...) + +""" + powh_safe(M::AbstractMatrix, p; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^p + +Raise an approximately Hermitian positive semi-definite matrix to the +power `p`. For diagonal-structured `M` (`isdiag(M) == true`), dispatches +to [`pow_diag_safe`](@ref) and skips the eigendecomposition. Otherwise, +computes via `M = V * D * V'` as `V * pow_diag_safe(D, p; atol, rtol) * V'`. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + +$(_clamp_kwargs_doc("M")) +""" +function powh_safe(M::AbstractMatrix, p; alg = nothing, kwargs...) 
+ isdiag(M) && return pow_diag_safe(M, p; kwargs...) + D, V = MAK.eigh_full(M, MAK.select_algorithm(MAK.eigh_full, M, alg)) + return V * pow_diag_safe(D, p; kwargs...) * V' +end + +""" + sqrth_safe(M::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(1//2) + +Square root of an approximately Hermitian positive semi-definite matrix. +Equivalent to `powh_safe(M, 1//2; alg, atol, rtol)`. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + +$(_clamp_kwargs_doc("M")) +""" +sqrth_safe(M::AbstractMatrix; kwargs...) = powh_safe(M, 1 // 2; kwargs...) + +""" + invsqrth_safe(M::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(-1//2) + +Inverse square root of an approximately Hermitian positive semi-definite +matrix. Equivalent to `powh_safe(M, -1//2; alg, atol, rtol)`. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + +$(_clamp_kwargs_doc("M")) +""" +invsqrth_safe(M::AbstractMatrix; kwargs...) = powh_safe(M, -1 // 2; kwargs...) + +for (gram, gram_with_pinv, eigh_full) in ( + (:gram_eigh_full, :gram_eigh_full_with_pinv, :eigh_full), + (:gram_eigh_full!!, :gram_eigh_full_with_pinv!!, :eigh_full!), + ) + @eval begin + function $gram(A::AbstractMatrix; alg = nothing, kwargs...) + D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) + return sqrth_safe(D; kwargs...) * V' + end + function $gram_with_pinv(A::AbstractMatrix; alg = nothing, kwargs...) + D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) + return sqrth_safe(D; kwargs...) * V', V * invsqrth_safe(D; kwargs...) 
+ end + end +end + +""" + gram_eigh_full(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X + gram_eigh_full!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X + +Gram factorization of a Hermitian positive semi-definite matrix via its +eigendecomposition: returns `X = sqrth_safe(D; atol, rtol) * V'` such +that `A ≈ X' * X`, where `A = V * D * V'`. Eigenvalues below `tol` (see +[`pow_diag_safe`](@ref)) are clamped to zero. The `!!` variant may +destroy `A`. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + +$(_clamp_kwargs_doc("A")) + +# Examples + +```jldoctest +julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full + +julia> B = [1.0 0.5; 0.5 2.0]; + +julia> A = B' * B; + +julia> X = gram_eigh_full(A); + +julia> X' * X ≈ A +true +``` + +See also [`gram_eigh_full_with_pinv`](@ref). +""" +gram_eigh_full, gram_eigh_full!! + +""" + gram_eigh_full_with_pinv(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y + gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y + +Like [`gram_eigh_full`](@ref), but additionally returns +`Y = V * invsqrth_safe(D; atol, rtol) ≈ pinv(X)` so that `X * Y ≈ I` on +the rank subspace. Eigenvalues below `tol` are clamped to zero in both +factors. The `!!` variant may destroy `A`. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + +$(_clamp_kwargs_doc("A")) + +# Examples + +```jldoctest +julia> using LinearAlgebra: I + +julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full_with_pinv + +julia> B = [1.0 0.5; 0.5 2.0]; + +julia> A = B' * B; + +julia> X, Y = gram_eigh_full_with_pinv(A); + +julia> X' * X ≈ A +true + +julia> X * Y ≈ I +true +``` +""" +gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!! 
+ for (svd, svd_trunc, svd_full, svd_compact) in ( (:svd, :svd_trunc, :svd_full, :svd_compact), (:svd!!, :svd_trunc!, :svd_full!, :svd_compact!), diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 5202364..a10e57e 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -1,7 +1,8 @@ module TensorAlgebra -export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar, - lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals +export contract, contract!, eigen, eigvals, factorize, gram_eigh_full, + gram_eigh_full_with_pinv, left_null, left_orth, left_polar, lq, qr, + right_null, right_orth, right_polar, orth, polar, svd, svdvals if VERSION >= v"1.11.0-DEV.469" eval(Meta.parse("public contractopadd!, matricizeop")) diff --git a/src/factorizations.jl b/src/factorizations.jl index 96fdc6f..6d5843d 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -31,6 +31,7 @@ end for f in ( :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize, :eigen, :eigvals, :svd, :svdvals, :left_null, :right_null, + :gram_eigh_full, :gram_eigh_full_with_pinv, ) @eval begin function $f( @@ -83,7 +84,7 @@ end qr(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Q, R Compute the QR decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -103,7 +104,7 @@ qr lq(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> L, Q Compute the LQ decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. 
These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -123,7 +124,7 @@ lq left_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> W, P Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -141,7 +142,7 @@ left_polar right_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> P, W Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -159,7 +160,7 @@ right_polar left_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> V, C Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -177,7 +178,7 @@ left_orth right_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> C, V Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. 
## Keyword arguments @@ -195,7 +196,7 @@ right_orth factorize(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y Compute the decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -213,7 +214,7 @@ factorize eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D, V Compute the eigenvalue decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -256,7 +257,7 @@ end eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D Compute the eigenvalues of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. The output is a vector of eigenvalues. ## Keyword arguments @@ -291,7 +292,7 @@ end svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> U, S, Vᴴ Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. 
## Keyword arguments @@ -332,7 +333,7 @@ end svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) -> S Compute the singular values of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. The output is a vector of singular values. See also `MatrixAlgebraKit.svd_vals!`. @@ -361,7 +362,7 @@ end left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> N Compute the left nullspace of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`. @@ -401,7 +402,7 @@ end right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Nᴴ Compute the right nullspace of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. @@ -433,3 +434,125 @@ end function right_null(A::AbstractArray, ndims_codomain::Val; kwargs...) return right_null!!(copy(A), ndims_codomain; kwargs...) end + +""" + gram_eigh_full(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X + gram_eigh_full(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> X + gram_eigh_full(A::AbstractArray, ndims_codomain::Val; kwargs...) -> X + gram_eigh_full(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) 
-> X + +Gram factorization of a generic N-dimensional array, interpreting it as a +Hermitian positive semi-definite linear map from the domain to the codomain +dimensions. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + +$(MatrixAlgebra._clamp_kwargs_doc("A")) + +# Examples + +```jldoctest +julia> using TensorAlgebra: contract, gram_eigh_full + +julia> B = randn(3, 2, 2); + +julia> A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d)); + +julia> X = gram_eigh_full(A, (:a, :b, :c, :d), (:a, :b), (:c, :d)); + +julia> A ≈ contract((:a, :b, :c, :d), conj(X), (:r, :a, :b), X, (:r, :c, :d)) +true +``` + +See also [`gram_eigh_full_with_pinv`](@ref) and +[`MatrixAlgebra.gram_eigh_full`](@ref). +""" +gram_eigh_full + +function gram_eigh_full!!( + style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs... + ) + A_mat = matricize(style, A, ndims_codomain) + X = MatrixAlgebra.gram_eigh_full!!(A_mat; kwargs...) + biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) + axes_codomain = first(blocks(axes(A)[biperm])) + axes_X = tuplemortar(((axes(X, 1),), axes_codomain)) + return unmatricize(style, X, axes_X) +end +function gram_eigh_full!!(A::AbstractArray, ndims_codomain::Val; kwargs...) + return gram_eigh_full!!(FusionStyle(A), A, ndims_codomain; kwargs...) +end + +function gram_eigh_full( + style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs... + ) + return gram_eigh_full!!(style, copy(A), ndims_codomain; kwargs...) +end +function gram_eigh_full(A::AbstractArray, ndims_codomain::Val; kwargs...) + return gram_eigh_full!!(copy(A), ndims_codomain; kwargs...) +end + +""" + gram_eigh_full_with_pinv(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y + gram_eigh_full_with_pinv(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) 
-> X, Y + gram_eigh_full_with_pinv(A::AbstractArray, ndims_codomain::Val; kwargs...) -> X, Y + gram_eigh_full_with_pinv(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y + +Like [`gram_eigh_full`](@ref), but additionally returns `Y ≈ pinv(X)` such +that `X * Y ≈ I` on the rank subspace. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + +$(MatrixAlgebra._clamp_kwargs_doc("A")) + +# Examples + +```jldoctest +julia> using LinearAlgebra: I + +julia> using TensorAlgebra: contract, gram_eigh_full_with_pinv + +julia> B = randn(8, 2, 2); + +julia> A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d)); + +julia> X, Y = gram_eigh_full_with_pinv(A, (:a, :b, :c, :d), (:a, :b), (:c, :d)); + +julia> A ≈ contract((:a, :b, :c, :d), conj(X), (:r, :a, :b), X, (:r, :c, :d)) +true + +julia> contract((:r, :s), X, (:r, :a, :b), Y, (:a, :b, :s)) ≈ I +true +``` + +See also [`MatrixAlgebra.gram_eigh_full_with_pinv`](@ref). +""" +gram_eigh_full_with_pinv + +function gram_eigh_full_with_pinv!!( + style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs... + ) + A_mat = matricize(style, A, ndims_codomain) + X, Y = MatrixAlgebra.gram_eigh_full_with_pinv!!(A_mat; kwargs...) + biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) + axes_codomain = first(blocks(axes(A)[biperm])) + axes_X = tuplemortar(((axes(X, 1),), axes_codomain)) + axes_Y = tuplemortar((axes_codomain, (axes(Y, 2),))) + return unmatricize(style, X, axes_X), unmatricize(style, Y, axes_Y) +end +function gram_eigh_full_with_pinv!!(A::AbstractArray, ndims_codomain::Val; kwargs...) + return gram_eigh_full_with_pinv!!(FusionStyle(A), A, ndims_codomain; kwargs...) +end + +function gram_eigh_full_with_pinv( + style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs... + ) + return gram_eigh_full_with_pinv!!(style, copy(A), ndims_codomain; kwargs...) +end +function gram_eigh_full_with_pinv(A::AbstractArray, ndims_codomain::Val; kwargs...) 
+ return gram_eigh_full_with_pinv!!(copy(A), ndims_codomain; kwargs...) +end diff --git a/test/test_exports.jl b/test/test_exports.jl index 0fb0d9b..5968127 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -9,6 +9,8 @@ using Test: @test, @testset :eigen, :eigvals, :factorize, + :gram_eigh_full, + :gram_eigh_full_with_pinv, :left_null, :left_orth, :left_polar, @@ -36,14 +38,24 @@ using Test: @test, @testset :eigvals!!, :factorize, :factorize!!, + :gram_eigh_full, + :gram_eigh_full!!, + :gram_eigh_full_with_pinv, + :gram_eigh_full_with_pinv!!, + :invsqrt_diag_safe, + :invsqrth_safe, :lq, :lq!!, :orth, :orth!!, :polar, :polar!!, + :pow_diag_safe, + :powh_safe, :qr, :qr!!, + :sqrt_diag_safe, + :sqrth_safe, :svd, :svd!!, :svdvals, diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index b8eec74..2a353ac 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,7 +1,8 @@ -using LinearAlgebra: LinearAlgebra, diag, norm +using LinearAlgebra: LinearAlgebra, I, diag, norm using MatrixAlgebraKit: truncrank -using TensorAlgebra: contract, eigen, eigvals, factorize, left_null, left_orth, left_polar, - lq, orth, polar, qr, right_null, right_orth, right_polar, svd, svdvals +using TensorAlgebra: contract, eigen, eigvals, factorize, gram_eigh_full, + gram_eigh_full_with_pinv, left_null, left_orth, left_polar, lq, orth, polar, qr, + right_null, right_orth, right_polar, svd, svdvals using Test: @test, @testset using TestExtras: @constinferred @@ -15,7 +16,7 @@ elts = (Float64, ComplexF64) labels_Q = (:b, :a) labels_R = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full = true) @test A == Acopy # should not have altered initial array A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) @@ -35,7 +36,7 @@ end labels_Q = (:b, :a) labels_R = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full = 
false) @test A == Acopy # should not have altered initial array A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) @@ -51,7 +52,7 @@ end labels_Q = (:d, :c) labels_L = (:b, :a) - Acopy = deepcopy(A) + Acopy = copy(A) L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full = true) @test A == Acopy # should not have altered initial array A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) @@ -68,7 +69,7 @@ end labels_Q = (:d, :c) labels_L = (:b, :a) - Acopy = deepcopy(A) + Acopy = copy(A) L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full = false) @test A == Acopy # should not have altered initial array A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) @@ -84,7 +85,7 @@ end labels_V = (:b, :a) labels_V′ = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) # type-unstable because of `ishermitian` difference D, V = eigen(A, labels_A, labels_V, labels_V′; ishermitian = false) @test A == Acopy # should not have altered initial array @@ -107,7 +108,7 @@ end labels_V = (:b, :a) labels_V′ = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) # type-unstable because of `ishermitian` difference D, V = eigen(A, labels_A, labels_V, labels_V′; ishermitian = true) @test A == Acopy # should not have altered initial array @@ -132,7 +133,7 @@ end labels_U = (:b, :a) labels_Vᴴ = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = Val(true)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) @@ -166,7 +167,7 @@ end labels_U = (:b, :a) labels_Vᴴ = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = Val(false)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) @@ -200,7 +201,7 @@ end labels_Vᴴ = (:d, :c) # test truncated SVD - Acopy = deepcopy(A) + Acopy = 
copy(A) _, S_untrunc, _ = svd(A, labels_A, labels_U, labels_Vᴴ) trunc = truncrank(size(S_untrunc, 1) - 1) @@ -219,7 +220,7 @@ end labels_codomain = (:b, :a) labels_domain = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) N = @constinferred left_null(A, labels_A, labels_codomain, labels_domain) @test A == Acopy # should not have altered initial array # N^ba_n' * A^ba_dc = 0 @@ -243,7 +244,7 @@ end labels_W = (:b, :a) labels_P = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for (W, P) in ( left_polar(A, labels_A, labels_W, labels_P), polar(A, labels_A, labels_W, labels_P; side = :left), @@ -262,7 +263,7 @@ end labels_P = (:b, :a) labels_W = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for (P, W) in ( right_polar(A, labels_A, labels_P, labels_W), polar(A, labels_A, labels_P, labels_W; side = :right), @@ -280,7 +281,7 @@ end labels_W = (:b, :a) labels_P = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for (W, P) in ( left_orth(A, labels_A, labels_W, labels_P), orth(A, labels_A, labels_W, labels_P; side = :left), @@ -299,7 +300,7 @@ end labels_P = (:b, :a) labels_W = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for (P, W) in ( right_orth(A, labels_A, labels_P, labels_W), orth(A, labels_A, labels_P, labels_W; side = :right), @@ -317,7 +318,7 @@ end labels_X = (:b, :a) labels_Y = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for orth in (:left, :right) X, Y = factorize(A, labels_A, labels_X, labels_Y; orth) @test A == Acopy # should not have altered initial array @@ -329,3 +330,56 @@ end @test A ≈ contract(labels_A, X, (labels_X..., :x), Y, (:x, labels_Y...)) end end + +# Gram factorization +# ------------------ +# Build a Hermitian positive semi-definite tensor A[a,b,c,d] with codomain +# (a, b) and domain (c, d): pick a random B[k, a, b] (k = aux), then form +# A = B' * B over k. By construction A ≈ X' * X for X[r, a, b] with rank r +# bounded by k (rank leg first, following the Cholesky `A = U' * U` +# convention). 
+@testset "Full-rank gram_eigh_full ($T)" for T in elts + B = randn(T, 6, 2, 3) # k = 6, codomain = (a, b) of size 2*3 = 6 -> full rank + A = contract((:a, :b, :c, :d), conj(B), (:k, :a, :b), B, (:k, :c, :d)) + labels_A = (:a, :b, :c, :d) + labels_X = (:a, :b) + labels_Y = (:c, :d) + + Acopy = copy(A) + X = @constinferred gram_eigh_full(A, labels_A, labels_X, labels_Y) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, conj(X), (:r, :a, :b), X, (:r, :c, :d)) + @test A ≈ A′ + @test size(X, 1) == size(A, 1) * size(A, 2) + + # `Val`, perm, and label entries agree. + @test gram_eigh_full(A, Val(2)) ≈ X + @test gram_eigh_full(A, (1, 2), (3, 4)) ≈ X + + # `with_pinv` variant: Y ≈ pinv(X) so X * Y ≈ I on the full-rank + # subspace. + X2, Y2 = @constinferred gram_eigh_full_with_pinv(A, labels_A, labels_X, labels_Y) + @test A ≈ contract(labels_A, conj(X2), (:r, :a, :b), X2, (:r, :c, :d)) + XY = contract((:r, :s), X2, (:r, :a, :b), Y2, (:a, :b, :s)) + @test XY ≈ I +end + +@testset "Rank-deficient gram_eigh_full ($T)" for T in elts + B = randn(T, 4, 2, 3) # k = 4 < codomain dim 6, so A is rank-4 + A = contract((:a, :b, :c, :d), conj(B), (:k, :a, :b), B, (:k, :c, :d)) + + # Recovery of A is independent of the `rtol` cutoff because all + # nonzero eigenvalues sit far above any reasonable threshold. + X = gram_eigh_full(A, Val(2); rtol = 1.0e-10) + @test A ≈ contract( + (:a, :b, :c, :d), conj(X), (:r, :a, :b), X, (:r, :c, :d) + ) + + # Moore–Penrose-like identity: X * Y * X ≈ X when Y is pinv(X). With + # rank-first X and rank-last Y, contract X[r, a, b] * Y[a, b, s] → P[r, s] + # (projector onto the rank subspace), then P * X → X. 
+ X2, Y2 = gram_eigh_full_with_pinv(A, Val(2); rtol = 1.0e-10) + P = contract((:r, :s), X2, (:r, :a, :b), Y2, (:a, :b, :s)) + PX = contract((:r, :c, :d), P, (:r, :s), X2, (:s, :c, :d)) + @test PX ≈ X2 +end diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 1123120..d4a6c92 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -288,4 +288,76 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test size(ṽ) == (0, n) @test norm(ũ * s̃ * ṽ) ≈ 0 end + + @testset "gram_eigh_full" begin + n = 5 + # Full-rank Hermitian PSD. Use a tall random factor so `B' * B` + # is comfortably full rank even at Float32 precision (a square + # random `B` can produce a `B' * B` whose smallest eigenvalue + # falls below the default rtol clamp on some seeds). + rng = StableRNG(123) + B = randn(rng, elt, 2n, n) + A = B' * B + X = MatrixAlgebra.gram_eigh_full(A) + @test X' * X ≈ A + @test size(X) == (n, n) + + X2, Y2 = MatrixAlgebra.gram_eigh_full_with_pinv(A) + @test X2' * X2 ≈ A + @test X2 * Y2 ≈ I(n) + + # `!!` variant accepts a destroyable copy. + Xb = MatrixAlgebra.gram_eigh_full!!(copy(A)) + @test Xb' * Xb ≈ A + + # Rank deficient: A is n×n of rank k < n. Recovery of A still holds; + # X * Y is the projector onto the rank-k subspace (idempotent, + # rank k), and P * X ≈ X (Moore–Penrose). + k = 3 + Brd = randn(rng, elt, k, n) + Ard = Brd' * Brd + Xrd, Yrd = MatrixAlgebra.gram_eigh_full_with_pinv( + Ard; rtol = sqrt(eps(real(elt))) + ) + @test Xrd' * Xrd ≈ Ard + P = Xrd * Yrd + @test P * P ≈ P + @test P * Xrd ≈ Xrd + end + + @testset "powh_safe / sqrth_safe / invsqrth_safe" begin + n = 4 + rng = StableRNG(123) + B = randn(rng, elt, 2n, n) + A = B' * B + sqrtA = MatrixAlgebra.sqrth_safe(A) + @test sqrtA * sqrtA ≈ A + @test sqrtA ≈ sqrtA' + + invsqrtA = MatrixAlgebra.invsqrth_safe(A) + @test invsqrtA * sqrtA ≈ I(n) + + # Integer power: passes through without clamping affecting result. 
+ @test MatrixAlgebra.powh_safe(A, 2) ≈ A * A + + # Diagonal-input dispatch path. + D = Diagonal(rand(real(elt), n)) + @test MatrixAlgebra.sqrth_safe(D) ≈ + MatrixAlgebra.pow_diag_safe(D, 1 // 2) + + # `isdiag` fast-path: a dense matrix that happens to be diagonal- + # structured goes through `pow_diag_safe`, not `eigh_full`, and + # the result is still a `Diagonal` (from `MAK.diagonal`). + Mdiag = Matrix(D) + sqrtMdiag = MatrixAlgebra.powh_safe(Mdiag, 1 // 2) + @test sqrtMdiag isa Diagonal + @test sqrtMdiag ≈ MatrixAlgebra.pow_diag_safe(D, 1 // 2) + + # `pow_diag_safe` directly on a diagonal-structured `AbstractMatrix`. + @test MatrixAlgebra.pow_diag_safe(Mdiag, 1 // 2) ≈ + MatrixAlgebra.pow_diag_safe(D, 1 // 2) + + # `pow_diag_safe` rejects non-diagonal inputs. + @test_throws ArgumentError MatrixAlgebra.pow_diag_safe(A, 1 // 2) + end end