From ea3c54c9c03970c4f2f5e7b6ad90e7b7fa967917 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Thu, 28 May 2026 15:52:36 -0400 Subject: [PATCH 01/19] Add gram_eigh_full / gram_eigh_full_with_pinv factorizations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a Hermitian-positive-semi-definite Gram factorization `A ≈ X * X'` (rank leg last, matching `eigen`/`svd`/`left_null`), plus a sibling that also returns `Y ≈ pinv(X)` so `Y * X ≈ I` on the rank subspace. Eigenvalues below a configurable tolerance are clamped to zero so the factorization is well defined for rank-deficient inputs. - `MatrixAlgebra.jl`: matrix-level primitives via a single `@eval` loop off `(eigh_full, eigh_full!)` (mirrors the existing `eigen` / `eigen!!` loop), plus `pinv_tol` and `sqrt_safe` helpers. - `factorizations.jl`: pair-returning `gram_eigh_full_with_pinv` joins the existing `qr` / `lq` / `factorize` `@eval` loop; both names join the existing perm / labels / biperm forwarder loop. Single-X `gram_eigh_full!!` is standalone, shape-matched to `left_null!!` / `right_null!!`. - Tests cover full-rank, rank-deficient, and `!!` variants across the matrix and tensor layers. Co-Authored-By: Claude Opus 4.7 (1M context) --- Project.toml | 2 +- src/MatrixAlgebra.jl | 81 ++++++++++++++++++++++++++++++++++++- src/TensorAlgebra.jl | 5 ++- src/factorizations.jl | 65 +++++++++++++++++++++++++++++ test/test_exports.jl | 6 +++ test/test_factorizations.jl | 57 ++++++++++++++++++++++++-- test/test_matrixalgebra.jl | 32 +++++++++++++++ 7 files changed, 241 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 64cf6ddd..91fbc88d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.9.2" +version = "0.9.3" authors = ["ITensor developers and contributors"] [workspace] diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 71e9507f..445876fb 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -6,6 +6,10 @@ export eigen, eigvals!!, factorize, factorize!!, + gram_eigh_full, + gram_eigh_full!!, + gram_eigh_full_with_pinv, + gram_eigh_full_with_pinv!!, lq, lq!!, orth, @@ -20,7 +24,7 @@ export eigen, svdvals!! import MatrixAlgebraKit as MAK -using LinearAlgebra: LinearAlgebra, norm +using LinearAlgebra: LinearAlgebra, Diagonal, diag, norm for (f, f_full, f_compact) in ( (:qr, :qr_full, :qr_compact), @@ -74,6 +78,81 @@ for (eigvals, eigh_vals, eig_vals) in end end +""" + pinv_tol(λ; atol=0, rtol=...) -> tol + pinv_tol(λ, pinv::NamedTuple) -> tol + +Tolerance used by [`gram_eigh_full`](@ref) and +[`gram_eigh_full_with_pinv`](@ref) to clamp small eigenvalues to zero: +`tol = max(atol, rtol * maximum(abs, λ))`. The `NamedTuple` form splats +its fields as keyword arguments. +""" +pinv_tol(λ, pinv::NamedTuple) = pinv_tol(λ; pinv...) +function pinv_tol( + λ; atol = zero(real(eltype(λ))), + rtol = iszero(atol) ? eps(real(eltype(λ))) * length(λ) : + zero(real(eltype(λ))) + ) + return max(atol, rtol * maximum(abs, λ; init = zero(real(eltype(λ))))) +end + +""" + sqrt_safe(a::Number, tol=MatrixAlgebraKit.defaulttol(a)) + +Compute `sqrt(a)` when `abs(a) ≥ tol`, otherwise return `zero(a)`. +""" +sqrt_safe(a::Number, tol = MAK.defaulttol(a)) = abs(a) < tol ? zero(a) : sqrt(a) + +for (gram, gram_with_pinv, eigh_full) in ( + (:gram_eigh_full, :gram_eigh_full_with_pinv, :eigh_full), + (:gram_eigh_full!!, :gram_eigh_full_with_pinv!!, :eigh_full!), + ) + @eval begin + function $gram(A::AbstractMatrix; alg = nothing, pinv = (;)) + D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) + λ = diag(D) + sqrtλ = map(l -> sqrt_safe(l, pinv_tol(λ, pinv)), λ) + return V * Diagonal(sqrtλ) + end + function $gram_with_pinv(A::AbstractMatrix; alg = nothing, pinv = (;)) + D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) + λ = diag(D) + sqrtλ = map(l -> sqrt_safe(l, pinv_tol(λ, pinv)), λ) + inv_sqrtλ = map(s -> iszero(s) ? s : inv(s), sqrtλ) + return V * Diagonal(sqrtλ), Diagonal(inv_sqrtλ) * V' + end + end +end + +""" + gram_eigh_full(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X + gram_eigh_full!!(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X + +Gram factorization of a Hermitian positive semi-definite matrix via its +eigendecomposition: returns `X = V * Diagonal(sqrt.(Λ))` such that +`A ≈ X * X'`, where `A = V * Diagonal(Λ) * V'`. Eigenvalues below +[`pinv_tol`](@ref) are clamped to zero. The `!!` variant may destroy `A`. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + - `pinv::NamedTuple`: forwarded to [`pinv_tol`](@ref) (e.g. `(; atol, rtol)`). + +See also [`gram_eigh_full_with_pinv`](@ref). +""" +gram_eigh_full, gram_eigh_full!! + +""" + gram_eigh_full_with_pinv(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X, Y + gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X, Y + +Like [`gram_eigh_full`](@ref), but additionally returns +`Y = Diagonal(inv.(sqrt.(Λ))) * V' ≈ pinv(X)` so that `Y * X ≈ I` on the +rank subspace. Eigenvalues below [`pinv_tol`](@ref) are clamped to zero +in both factors. The `!!` variant may destroy `A`. +""" +gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!! + for (svd, svd_trunc, svd_full, svd_compact) in ( (:svd, :svd_trunc, :svd_full, :svd_compact), (:svd!!, :svd_trunc!, :svd_full!, :svd_compact!), diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index 5202364f..a10e57e6 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -1,7 +1,8 @@ module TensorAlgebra -export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar, - lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals +export contract, contract!, eigen, eigvals, factorize, gram_eigh_full, + gram_eigh_full_with_pinv, left_null, left_orth, left_polar, lq, qr, + right_null, right_orth, right_polar, orth, polar, svd, svdvals if VERSION >= v"1.11.0-DEV.469" eval(Meta.parse("public contractopadd!, matricizeop")) diff --git a/src/factorizations.jl b/src/factorizations.jl index 96fdc6f6..ca5a6f80 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -11,6 +11,7 @@ for (f, f_mat) in ( (:right_orth, :(MatrixAlgebraKit.right_orth)), (:orth, :(MatrixAlgebra.orth)), (:factorize, :(MatrixAlgebra.factorize)), + (:gram_eigh_full_with_pinv, :(MatrixAlgebra.gram_eigh_full_with_pinv)), ) @eval begin function $f(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...) @@ -31,6 +32,7 @@ end for f in ( :qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize, :eigen, :eigvals, :svd, :svdvals, :left_null, :right_null, + :gram_eigh_full, :gram_eigh_full_with_pinv, ) @eval begin function $f( @@ -433,3 +435,66 @@ end function right_null(A::AbstractArray, ndims_codomain::Val; kwargs...) return right_null!!(copy(A), ndims_codomain; kwargs...) end + +""" + gram_eigh_full(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X + gram_eigh_full(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> X + gram_eigh_full(A::AbstractArray, ndims_codomain::Val; kwargs...) -> X + gram_eigh_full(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X + +Gram factorization of a generic N-dimensional array, interpreting it as a +Hermitian positive semi-definite linear map from the domain to the codomain +indices. Returns `X` such that `A ≈ X * X'` (contracted on the rank leg). + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + - `pinv::NamedTuple`: tolerance options used to clamp small eigenvalues to + zero (see `MatrixAlgebra.pinv_tol`). + +See also [`gram_eigh_full_with_pinv`](@ref) and +`MatrixAlgebra.gram_eigh_full`. +""" +gram_eigh_full + +function gram_eigh_full!!( + style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs... + ) + A_mat = matricize(style, A, ndims_codomain) + X = MatrixAlgebra.gram_eigh_full!!(A_mat; kwargs...) + biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) + axes_codomain = first(blocks(axes(A)[biperm])) + axes_X = tuplemortar((axes_codomain, (axes(X, 2),))) + return unmatricize(style, X, axes_X) +end +function gram_eigh_full!!(A::AbstractArray, ndims_codomain::Val; kwargs...) + return gram_eigh_full!!(FusionStyle(A), A, ndims_codomain; kwargs...) +end + +function gram_eigh_full( + style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs... + ) + return gram_eigh_full!!(style, copy(A), ndims_codomain; kwargs...) +end +function gram_eigh_full(A::AbstractArray, ndims_codomain::Val; kwargs...) + return gram_eigh_full!!(copy(A), ndims_codomain; kwargs...) +end + +""" + gram_eigh_full_with_pinv(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y + gram_eigh_full_with_pinv(A::AbstractArray, perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}; kwargs...) -> X, Y + gram_eigh_full_with_pinv(A::AbstractArray, ndims_codomain::Val; kwargs...) -> X, Y + gram_eigh_full_with_pinv(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y + +Like [`gram_eigh_full`](@ref), but additionally returns `Y ≈ pinv(X)` such +that `Y * X ≈ I` on the rank subspace. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + - `pinv::NamedTuple`: tolerance options used to clamp small eigenvalues to + zero in both `X` and `Y` (see `MatrixAlgebra.pinv_tol`). + +See also `MatrixAlgebra.gram_eigh_full_with_pinv`. +""" +gram_eigh_full_with_pinv diff --git a/test/test_exports.jl b/test/test_exports.jl index 0fb0d9b4..b095c9f7 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -9,6 +9,8 @@ using Test: @test, @testset :eigen, :eigvals, :factorize, + :gram_eigh_full, + :gram_eigh_full_with_pinv, :left_null, :left_orth, :left_polar, @@ -36,6 +38,10 @@ using Test: @test, @testset :eigvals!!, :factorize, :factorize!!, + :gram_eigh_full, + :gram_eigh_full!!, + :gram_eigh_full_with_pinv, + :gram_eigh_full_with_pinv!!, :lq, :lq!!, :orth, diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index b8eec74b..1c8f1096 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,7 +1,8 @@ -using LinearAlgebra: LinearAlgebra, diag, norm +using LinearAlgebra: LinearAlgebra, I, diag, norm using MatrixAlgebraKit: truncrank -using TensorAlgebra: contract, eigen, eigvals, factorize, left_null, left_orth, left_polar, - lq, orth, polar, qr, right_null, right_orth, right_polar, svd, svdvals +using TensorAlgebra: contract, eigen, eigvals, factorize, gram_eigh_full, + gram_eigh_full_with_pinv, left_null, left_orth, left_polar, lq, orth, polar, qr, + right_null, right_orth, right_polar, svd, svdvals using Test: @test, @testset using TestExtras: @constinferred @@ -329,3 +330,53 @@ end @test A ≈ contract(labels_A, X, (labels_X..., :x), Y, (:x, labels_Y...)) end end + +# Gram factorization +# ------------------ +# Build a Hermitian positive semi-definite tensor A[a,b,c,d] with codomain +# (a, b) and domain (c, d): pick a random B[a, b, k] (k = aux), then form +# A = B * B' over k. By construction A ≈ X * X' for X[a, b, r] with rank r +# bounded by k. +@testset "Full-rank gram_eigh_full ($T)" for T in elts + B = randn(T, 2, 3, 6) # k = 6, codomain = (a, b) of size 2*3 = 6 -> full rank + A = contract((:a, :b, :c, :d), B, (:a, :b, :k), conj(B), (:c, :d, :k)) + labels_A = (:a, :b, :c, :d) + labels_X = (:a, :b) + labels_Y = (:c, :d) + + Acopy = deepcopy(A) + X = @constinferred gram_eigh_full(A, labels_A, labels_X, labels_Y) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, X, (:a, :b, :r), conj(X), (:c, :d, :r)) + @test A ≈ A′ + @test size(X, 3) == size(A, 1) * size(A, 2) + + # `Val`, perm, and label entries agree. + @test gram_eigh_full(A, Val(2)) ≈ X + @test gram_eigh_full(A, (1, 2), (3, 4)) ≈ X + + # `with_pinv` variant: Y ≈ pinv(X) so Y * X ≈ I on the full-rank + # subspace. + X2, Y2 = @constinferred gram_eigh_full_with_pinv(A, labels_A, labels_X, labels_Y) + @test A ≈ contract(labels_A, X2, (:a, :b, :r), conj(X2), (:c, :d, :r)) + YX = contract((:r, :s), Y2, (:r, :a, :b), X2, (:a, :b, :s)) + @test YX ≈ I +end + +@testset "Rank-deficient gram_eigh_full ($T)" for T in elts + B = randn(T, 2, 3, 4) # k = 4 < codomain dim 6, so A is rank-4 + A = contract((:a, :b, :c, :d), B, (:a, :b, :k), conj(B), (:c, :d, :k)) + + # Recovery of A is independent of the `pinv` tolerance because all + # nonzero eigenvalues sit far above any reasonable pinv cutoff. + X = gram_eigh_full(A, Val(2); pinv = (; rtol = 1.0e-10)) + @test A ≈ contract( + (:a, :b, :c, :d), X, (:a, :b, :r), conj(X), (:c, :d, :r) + ) + + # Moore–Penrose-like identity: X * Y * X ≈ X when Y is pinv(X). + X2, Y2 = gram_eigh_full_with_pinv(A, Val(2); pinv = (; rtol = 1.0e-10)) + P = contract((:a, :b, :c, :d), X2, (:a, :b, :r), Y2, (:r, :c, :d)) + XPX = contract((:a, :b, :r), P, (:a, :b, :c, :d), X2, (:c, :d, :r)) + @test XPX ≈ X2 +end diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 1123120d..7ecbeca6 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -288,4 +288,36 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test size(ṽ) == (0, n) @test norm(ũ * s̃ * ṽ) ≈ 0 end + + @testset "gram_eigh_full" begin + n = 5 + # Full-rank Hermitian PSD. + B = randn(elt, n, n) + A = B * B' + X = MatrixAlgebra.gram_eigh_full(A) + @test X * X' ≈ A + @test size(X) == (n, n) + + X2, Y2 = MatrixAlgebra.gram_eigh_full_with_pinv(A) + @test X2 * X2' ≈ A + @test Y2 * X2 ≈ I(n) + + # `!!` variant accepts a destroyable copy. + Xb = MatrixAlgebra.gram_eigh_full!!(copy(A)) + @test Xb * Xb' ≈ A + + # Rank deficient: A is n×n of rank k < n. Recovery of A still holds; + # Y * X is the projector onto the rank-k subspace (idempotent, + # rank k), and X * Y * X ≈ X (Moore–Penrose). + k = 3 + Brd = randn(elt, n, k) + Ard = Brd * Brd' + Xrd, Yrd = MatrixAlgebra.gram_eigh_full_with_pinv( + Ard; pinv = (; rtol = sqrt(eps(real(elt)))) + ) + @test Xrd * Xrd' ≈ Ard + P = Yrd * Xrd + @test P * P ≈ P + @test Xrd * P ≈ Xrd + end end From 8e5152feb2efd165c842393f6f7bdb3dddb025fa Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Thu, 28 May 2026 17:53:35 -0400 Subject: [PATCH 02/19] =?UTF-8?q?Switch=20gram=5Feigh=5Ffull=20to=20rank-f?= =?UTF-8?q?irst=20(A=20=E2=89=88=20X'X)=20convention?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Returns `X = Diagonal(sqrt.(Λ)) * V'` such that `A ≈ X' * X`, with the rank axis as the first leg in the tensor layout (and `Y = V * Diagonal(inv.(sqrt.(Λ)))` such that `X * Y ≈ I` on the rank subspace). Motivation: this matches Julia stdlib `LinearAlgebra.cholesky` (`A = U' * U`), Matlab's `chol`, and the standard Gram-matrix expositions on Wikipedia. The rank-leg-last alternative — also defensible since `eigh_full`'s `V` comes out with eigenvectors as columns — is a less common reading and conflicted with the convention the downstream ITensorNetworksNext consumer was originally written against. Implementation notes: - `gram_eigh_full_with_pinv` no longer fits the `qr` / `lq` / `factorize` pair `@eval` loop in `factorizations.jl`: that loop produces `(codomain..., rank)` on the left factor and `(rank, domain...)` on the right, but rank-first gram needs `(rank, codomain...)` on `X` and `(codomain..., rank)` on `Y`. Pulled out as a standalone `!!` / non-bang FusionStyle method pair instead. The single-X `gram_eigh_full!!` shape also flips to `((rank,), codomain...)`. - Tests updated to assert `X' * X ≈ A`, `X * Y ≈ I`, and the rank-first Moore-Penrose form `P * X ≈ X` with `P = X * Y`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/MatrixAlgebra.jl | 13 ++++++++----- src/factorizations.jl | 32 ++++++++++++++++++++++++++---- test/test_factorizations.jl | 39 ++++++++++++++++++++----------------- test/test_matrixalgebra.jl | 18 ++++++++--------- 4 files changed, 66 insertions(+), 36 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 445876fb..aed9b753 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -112,14 +112,14 @@ for (gram, gram_with_pinv, eigh_full) in ( D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) λ = diag(D) sqrtλ = map(l -> sqrt_safe(l, pinv_tol(λ, pinv)), λ) - return V * Diagonal(sqrtλ) + return Diagonal(sqrtλ) * V' end function $gram_with_pinv(A::AbstractMatrix; alg = nothing, pinv = (;)) D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) λ = diag(D) sqrtλ = map(l -> sqrt_safe(l, pinv_tol(λ, pinv)), λ) inv_sqrtλ = map(s -> iszero(s) ? s : inv(s), sqrtλ) - return V * Diagonal(sqrtλ), Diagonal(inv_sqrtλ) * V' + return Diagonal(sqrtλ) * V', V * Diagonal(inv_sqrtλ) end end end @@ -129,10 +129,13 @@ end gram_eigh_full!!(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X Gram factorization of a Hermitian positive semi-definite matrix via its -eigendecomposition: returns `X = V * Diagonal(sqrt.(Λ))` such that -`A ≈ X * X'`, where `A = V * Diagonal(Λ) * V'`. Eigenvalues below +eigendecomposition: returns `X = Diagonal(sqrt.(Λ)) * V'` such that +`A ≈ X' * X`, where `A = V * Diagonal(Λ) * V'`. Eigenvalues below [`pinv_tol`](@ref) are clamped to zero. The `!!` variant may destroy `A`. +The orientation follows Julia's `LinearAlgebra.cholesky` convention +(`A = U' * U`) and standard Gram-matrix expositions. + ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. @@ -147,7 +150,7 @@ gram_eigh_full, gram_eigh_full!! gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X, Y Like [`gram_eigh_full`](@ref), but additionally returns -`Y = Diagonal(inv.(sqrt.(Λ))) * V' ≈ pinv(X)` so that `Y * X ≈ I` on the +`Y = V * Diagonal(inv.(sqrt.(Λ))) ≈ pinv(X)` so that `X * Y ≈ I` on the rank subspace. Eigenvalues below [`pinv_tol`](@ref) are clamped to zero in both factors. The `!!` variant may destroy `A`. """ diff --git a/src/factorizations.jl b/src/factorizations.jl index ca5a6f80..545dd7c5 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -11,7 +11,6 @@ for (f, f_mat) in ( (:right_orth, :(MatrixAlgebraKit.right_orth)), (:orth, :(MatrixAlgebra.orth)), (:factorize, :(MatrixAlgebra.factorize)), - (:gram_eigh_full_with_pinv, :(MatrixAlgebra.gram_eigh_full_with_pinv)), ) @eval begin function $f(style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs...) @@ -444,7 +443,8 @@ end Gram factorization of a generic N-dimensional array, interpreting it as a Hermitian positive semi-definite linear map from the domain to the codomain -indices. Returns `X` such that `A ≈ X * X'` (contracted on the rank leg). +indices. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). +The orientation follows Julia's `LinearAlgebra.cholesky` convention. ## Keyword arguments @@ -464,7 +464,7 @@ function gram_eigh_full!!( X = MatrixAlgebra.gram_eigh_full!!(A_mat; kwargs...) biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) axes_codomain = first(blocks(axes(A)[biperm])) - axes_X = tuplemortar((axes_codomain, (axes(X, 2),))) + axes_X = tuplemortar(((axes(X, 1),), axes_codomain)) return unmatricize(style, X, axes_X) end function gram_eigh_full!!(A::AbstractArray, ndims_codomain::Val; kwargs...) @@ -487,7 +487,7 @@ end gram_eigh_full_with_pinv(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y Like [`gram_eigh_full`](@ref), but additionally returns `Y ≈ pinv(X)` such -that `Y * X ≈ I` on the rank subspace. +that `X * Y ≈ I` on the rank subspace. ## Keyword arguments @@ -498,3 +498,27 @@ that `Y * X ≈ I` on the rank subspace. See also `MatrixAlgebra.gram_eigh_full_with_pinv`. """ gram_eigh_full_with_pinv + +function gram_eigh_full_with_pinv!!( + style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs... + ) + A_mat = matricize(style, A, ndims_codomain) + X, Y = MatrixAlgebra.gram_eigh_full_with_pinv!!(A_mat; kwargs...) + biperm = trivialbiperm(ndims_codomain, Val(ndims(A))) + axes_codomain = first(blocks(axes(A)[biperm])) + axes_X = tuplemortar(((axes(X, 1),), axes_codomain)) + axes_Y = tuplemortar((axes_codomain, (axes(Y, 2),))) + return unmatricize(style, X, axes_X), unmatricize(style, Y, axes_Y) +end +function gram_eigh_full_with_pinv!!(A::AbstractArray, ndims_codomain::Val; kwargs...) + return gram_eigh_full_with_pinv!!(FusionStyle(A), A, ndims_codomain; kwargs...) +end + +function gram_eigh_full_with_pinv( + style::FusionStyle, A::AbstractArray, ndims_codomain::Val; kwargs... + ) + return gram_eigh_full_with_pinv!!(style, copy(A), ndims_codomain; kwargs...) +end +function gram_eigh_full_with_pinv(A::AbstractArray, ndims_codomain::Val; kwargs...) + return gram_eigh_full_with_pinv!!(copy(A), ndims_codomain; kwargs...) +end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 1c8f1096..356c3110 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -334,12 +334,13 @@ end # Gram factorization # ------------------ # Build a Hermitian positive semi-definite tensor A[a,b,c,d] with codomain -# (a, b) and domain (c, d): pick a random B[a, b, k] (k = aux), then form -# A = B * B' over k. By construction A ≈ X * X' for X[a, b, r] with rank r -# bounded by k. +# (a, b) and domain (c, d): pick a random B[k, a, b] (k = aux), then form +# A = B' * B over k. By construction A ≈ X' * X for X[r, a, b] with rank r +# bounded by k (rank leg first, following the Cholesky `A = U' * U` +# convention). @testset "Full-rank gram_eigh_full ($T)" for T in elts - B = randn(T, 2, 3, 6) # k = 6, codomain = (a, b) of size 2*3 = 6 -> full rank - A = contract((:a, :b, :c, :d), B, (:a, :b, :k), conj(B), (:c, :d, :k)) + B = randn(T, 6, 2, 3) # k = 6, codomain = (a, b) of size 2*3 = 6 -> full rank + A = contract((:a, :b, :c, :d), conj(B), (:k, :a, :b), B, (:k, :c, :d)) labels_A = (:a, :b, :c, :d) labels_X = (:a, :b) labels_Y = (:c, :d) @@ -347,36 +348,38 @@ end Acopy = deepcopy(A) X = @constinferred gram_eigh_full(A, labels_A, labels_X, labels_Y) @test A == Acopy # should not have altered initial array - A′ = contract(labels_A, X, (:a, :b, :r), conj(X), (:c, :d, :r)) + A′ = contract(labels_A, conj(X), (:r, :a, :b), X, (:r, :c, :d)) @test A ≈ A′ - @test size(X, 3) == size(A, 1) * size(A, 2) + @test size(X, 1) == size(A, 1) * size(A, 2) # `Val`, perm, and label entries agree. @test gram_eigh_full(A, Val(2)) ≈ X @test gram_eigh_full(A, (1, 2), (3, 4)) ≈ X - # `with_pinv` variant: Y ≈ pinv(X) so Y * X ≈ I on the full-rank + # `with_pinv` variant: Y ≈ pinv(X) so X * Y ≈ I on the full-rank # subspace. X2, Y2 = @constinferred gram_eigh_full_with_pinv(A, labels_A, labels_X, labels_Y) - @test A ≈ contract(labels_A, X2, (:a, :b, :r), conj(X2), (:c, :d, :r)) - YX = contract((:r, :s), Y2, (:r, :a, :b), X2, (:a, :b, :s)) - @test YX ≈ I + @test A ≈ contract(labels_A, conj(X2), (:r, :a, :b), X2, (:r, :c, :d)) + XY = contract((:r, :s), X2, (:r, :a, :b), Y2, (:a, :b, :s)) + @test XY ≈ I end @testset "Rank-deficient gram_eigh_full ($T)" for T in elts - B = randn(T, 2, 3, 4) # k = 4 < codomain dim 6, so A is rank-4 - A = contract((:a, :b, :c, :d), B, (:a, :b, :k), conj(B), (:c, :d, :k)) + B = randn(T, 4, 2, 3) # k = 4 < codomain dim 6, so A is rank-4 + A = contract((:a, :b, :c, :d), conj(B), (:k, :a, :b), B, (:k, :c, :d)) # Recovery of A is independent of the `pinv` tolerance because all # nonzero eigenvalues sit far above any reasonable pinv cutoff. X = gram_eigh_full(A, Val(2); pinv = (; rtol = 1.0e-10)) @test A ≈ contract( - (:a, :b, :c, :d), X, (:a, :b, :r), conj(X), (:c, :d, :r) + (:a, :b, :c, :d), conj(X), (:r, :a, :b), X, (:r, :c, :d) ) - # Moore–Penrose-like identity: X * Y * X ≈ X when Y is pinv(X). + # Moore–Penrose-like identity: X * Y * X ≈ X when Y is pinv(X). With + # rank-first X and rank-last Y, contract X[r, a, b] * Y[a, b, s] → P[r, s] + # (projector onto the rank subspace), then P * X → X. X2, Y2 = gram_eigh_full_with_pinv(A, Val(2); pinv = (; rtol = 1.0e-10)) - P = contract((:a, :b, :c, :d), X2, (:a, :b, :r), Y2, (:r, :c, :d)) - XPX = contract((:a, :b, :r), P, (:a, :b, :c, :d), X2, (:c, :d, :r)) - @test XPX ≈ X2 + P = contract((:r, :s), X2, (:r, :a, :b), Y2, (:a, :b, :s)) + PX = contract((:r, :c, :d), P, (:r, :s), X2, (:s, :c, :d)) + @test PX ≈ X2 end diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 7ecbeca6..179f2ba5 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -295,29 +295,29 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) B = randn(elt, n, n) A = B * B' X = MatrixAlgebra.gram_eigh_full(A) - @test X * X' ≈ A + @test X' * X ≈ A @test size(X) == (n, n) X2, Y2 = MatrixAlgebra.gram_eigh_full_with_pinv(A) - @test X2 * X2' ≈ A - @test Y2 * X2 ≈ I(n) + @test X2' * X2 ≈ A + @test X2 * Y2 ≈ I(n) # `!!` variant accepts a destroyable copy. Xb = MatrixAlgebra.gram_eigh_full!!(copy(A)) - @test Xb * Xb' ≈ A + @test Xb' * Xb ≈ A # Rank deficient: A is n×n of rank k < n. Recovery of A still holds; - # Y * X is the projector onto the rank-k subspace (idempotent, - # rank k), and X * Y * X ≈ X (Moore–Penrose). + # X * Y is the projector onto the rank-k subspace (idempotent, + # rank k), and P * X ≈ X (Moore–Penrose). k = 3 Brd = randn(elt, n, k) Ard = Brd * Brd' Xrd, Yrd = MatrixAlgebra.gram_eigh_full_with_pinv( Ard; pinv = (; rtol = sqrt(eps(real(elt)))) ) - @test Xrd * Xrd' ≈ Ard - P = Yrd * Xrd + @test Xrd' * Xrd ≈ Ard + P = Xrd * Yrd @test P * P ≈ P - @test Xrd * P ≈ Xrd + @test P * Xrd ≈ Xrd end end From 5ae843352352cc4b84421dd515b6e162c739ea14 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Thu, 28 May 2026 20:08:35 -0400 Subject: [PATCH 03/19] Tidy gram_eigh_full docstrings Drop the Cholesky-convention sentence and switch "indices" to "dimensions" in the tensor-layer docstring. --- src/MatrixAlgebra.jl | 3 --- src/factorizations.jl | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index aed9b753..46f3db37 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -133,9 +133,6 @@ eigendecomposition: returns `X = Diagonal(sqrt.(Λ)) * V'` such that `A ≈ X' * X`, where `A = V * Diagonal(Λ) * V'`. Eigenvalues below [`pinv_tol`](@ref) are clamped to zero. The `!!` variant may destroy `A`. -The orientation follows Julia's `LinearAlgebra.cholesky` convention -(`A = U' * U`) and standard Gram-matrix expositions. - ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. diff --git a/src/factorizations.jl b/src/factorizations.jl index 545dd7c5..13d1d501 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -443,8 +443,7 @@ end Gram factorization of a generic N-dimensional array, interpreting it as a Hermitian positive semi-definite linear map from the domain to the codomain -indices. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). -The orientation follows Julia's `LinearAlgebra.cholesky` convention. +dimensions. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). ## Keyword arguments From b17041115e0d825fa352ffe7f95a1b39b330b175 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Thu, 28 May 2026 20:13:29 -0400 Subject: [PATCH 04/19] Use 'dimensions' instead of 'indices' in factorization docstrings --- src/factorizations.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index 13d1d501..48213bff 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -84,7 +84,7 @@ end qr(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Q, R Compute the QR decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -104,7 +104,7 @@ qr lq(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> L, Q Compute the LQ decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -124,7 +124,7 @@ lq left_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> W, P Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -142,7 +142,7 @@ left_polar right_polar(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> P, W Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -160,7 +160,7 @@ right_polar left_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> V, C Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -178,7 +178,7 @@ left_orth right_orth(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> C, V Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -196,7 +196,7 @@ right_orth factorize(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> X, Y Compute the decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -214,7 +214,7 @@ factorize eigen(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D, V Compute the eigenvalue decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -257,7 +257,7 @@ end eigvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> D Compute the eigenvalues of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. The output is a vector of eigenvalues. ## Keyword arguments @@ -292,7 +292,7 @@ end svd(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> U, S, Vᴴ Compute the SVD decomposition of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. ## Keyword arguments @@ -333,7 +333,7 @@ end svdvals(A::AbstractArray, biperm::AbstractBlockPermutation{2}) -> S Compute the singular values of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. The output is a vector of singular values. See also `MatrixAlgebraKit.svd_vals!`. @@ -362,7 +362,7 @@ end left_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> N Compute the left nullspace of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. The output satisfies `N' * A ≈ 0` and `N' * N ≈ I`. @@ -402,7 +402,7 @@ end right_null(A::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...) -> Nᴴ Compute the right nullspace of a generic N-dimensional array, by interpreting it as -a linear map from the domain to the codomain indices. These can be specified either via +a linear map from the domain to the codomain dimensions. These can be specified either via their labels or directly through a bi-permutation. The output satisfies `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. From b975b352cffcf7083bab86f7b0e42fa0a8a1df18 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Thu, 28 May 2026 20:47:08 -0400 Subject: [PATCH 05/19] Promote gram_eigh_full helpers to matrix operations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `pinv_tol`, `sqrt_safe`, and a new `inv_safe` now dispatch on `Diagonal` matrices rather than vectors/scalars. The gram functions no longer call `diag(D)` to pull out the eigenvalues — they operate on `D` directly, which sets up specialization for symmetric/graded diagonal types later. --- src/MatrixAlgebra.jl | 50 +++++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 46f3db37..02382b0b 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -24,7 +24,7 @@ export eigen, svdvals!! import MatrixAlgebraKit as MAK -using LinearAlgebra: LinearAlgebra, Diagonal, diag, norm +using LinearAlgebra: LinearAlgebra, Diagonal, norm for (f, f_full, f_compact) in ( (:qr, :qr_full, :qr_compact), @@ -79,29 +79,41 @@ for (eigvals, eigh_vals, eig_vals) in end """ - pinv_tol(λ; atol=0, rtol=...) -> tol - pinv_tol(λ, pinv::NamedTuple) -> tol + pinv_tol(D; atol=0, rtol=...) -> tol + pinv_tol(D, pinv::NamedTuple) -> tol Tolerance used by [`gram_eigh_full`](@ref) and -[`gram_eigh_full_with_pinv`](@ref) to clamp small eigenvalues to zero: -`tol = max(atol, rtol * maximum(abs, λ))`. The `NamedTuple` form splats -its fields as keyword arguments. +[`gram_eigh_full_with_pinv`](@ref) to clamp small eigenvalues of a +diagonal matrix `D` to zero: `tol = max(atol, rtol * maximum(abs, D.diag))`. +The `NamedTuple` form splats its fields as keyword arguments. """ -pinv_tol(λ, pinv::NamedTuple) = pinv_tol(λ; pinv...) +pinv_tol(D::Diagonal, pinv::NamedTuple) = pinv_tol(D; pinv...) function pinv_tol( - λ; atol = zero(real(eltype(λ))), - rtol = iszero(atol) ? eps(real(eltype(λ))) * length(λ) : - zero(real(eltype(λ))) + D::Diagonal; atol = zero(real(eltype(D))), + rtol = iszero(atol) ? eps(real(eltype(D))) * size(D, 1) : + zero(real(eltype(D))) ) - return max(atol, rtol * maximum(abs, λ; init = zero(real(eltype(λ))))) + return max(atol, rtol * maximum(abs, D.diag; init = zero(real(eltype(D))))) end """ - sqrt_safe(a::Number, tol=MatrixAlgebraKit.defaulttol(a)) + sqrt_safe(D::AbstractMatrix, tol=MatrixAlgebraKit.defaulttol(D)) -Compute `sqrt(a)` when `abs(a) ≥ tol`, otherwise return `zero(a)`. +Compute `sqrt(D)`, clamping diagonal entries with `abs < tol` to zero. """ -sqrt_safe(a::Number, tol = MAK.defaulttol(a)) = abs(a) < tol ? zero(a) : sqrt(a) +function sqrt_safe(D::Diagonal, tol = MAK.defaulttol(D)) + return Diagonal(map(d -> abs(d) < tol ? zero(d) : sqrt(d), D.diag)) +end + +""" + inv_safe(D::AbstractMatrix) + +Invert each nonzero diagonal entry of `D`, leaving exact-zero entries +unchanged. Used to form `pinv(sqrt_safe(D, tol))` without re-thresholding. +""" +function inv_safe(D::Diagonal) + return Diagonal(map(d -> iszero(d) ? d : inv(d), D.diag)) +end for (gram, gram_with_pinv, eigh_full) in ( (:gram_eigh_full, :gram_eigh_full_with_pinv, :eigh_full), @@ -110,16 +122,12 @@ for (gram, gram_with_pinv, eigh_full) in ( @eval begin function $gram(A::AbstractMatrix; alg = nothing, pinv = (;)) D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) - λ = diag(D) - sqrtλ = map(l -> sqrt_safe(l, pinv_tol(λ, pinv)), λ) - return Diagonal(sqrtλ) * V' + return sqrt_safe(D, pinv_tol(D, pinv)) * V' end function $gram_with_pinv(A::AbstractMatrix; alg = nothing, pinv = (;)) D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) - λ = diag(D) - sqrtλ = map(l -> sqrt_safe(l, pinv_tol(λ, pinv)), λ) - inv_sqrtλ = map(s -> iszero(s) ? s : inv(s), sqrtλ) - return Diagonal(sqrtλ) * V', V * Diagonal(inv_sqrtλ) + sqrtD = sqrt_safe(D, pinv_tol(D, pinv)) + return sqrtD * V', V * inv_safe(sqrtD) end end end From ca8d5b909cac240ac799e98a483309d68bdcbfbd Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 10:49:11 -0400 Subject: [PATCH 06/19] Refactor gram_eigh_full clamping helpers to pow_safe family Introduce pow_safe(D, p; atol, rtol) as the single hook point for diagonal-like matrices, with sqrt_safe and invsqrt_safe as thin wrappers on top of it. Drop pinv_tol entirely: tolerance computation is now internal to pow_safe and computed from D's singular values. gram_eigh_full(*) take atol and rtol directly instead of bundling them through a pinv NamedTuple. Default rtol is eps^(3//4), matching PEPSKit.jl's sdiag_pow convention. Future extension to graded/block diagonal types requires only a pow_safe specialization; sqrt_safe and invsqrt_safe inherit automatically. --- src/MatrixAlgebra.jl | 85 ++++++++++++++++++++----------------- src/factorizations.jl | 8 ++-- test/test_exports.jl | 3 ++ test/test_factorizations.jl | 8 ++-- test/test_matrixalgebra.jl | 2 +- 5 files changed, 59 insertions(+), 47 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 02382b0b..be2db9a2 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -10,14 +10,17 @@ export eigen, gram_eigh_full!!, gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!!, + invsqrt_safe, lq, lq!!, orth, orth!!, polar, polar!!, + pow_safe, qr, qr!!, + sqrt_safe, svd, svd!!, svdvals, @@ -79,85 +82,91 @@ for (eigvals, eigh_vals, eig_vals) in end """ - pinv_tol(D; atol=0, rtol=...) -> tol - pinv_tol(D, pinv::NamedTuple) -> tol + pow_safe(D, p; atol=0, rtol=...) -> D^p -Tolerance used by [`gram_eigh_full`](@ref) and -[`gram_eigh_full_with_pinv`](@ref) to clamp small eigenvalues of a -diagonal matrix `D` to zero: `tol = max(atol, rtol * maximum(abs, D.diag))`. -The `NamedTuple` form splats its fields as keyword arguments. +Raise an approximately Hermitian positive semi-definite matrix `D` to +the power `p`. Diagonal entries `d` with `abs(d) < tol` are clamped to +zero before exponentiation, where +`tol = max(atol, rtol * maximum(abs, diagview(D)))`. Default +`rtol = eps^(3//4)` (matching PEPSKit's `sdiag_pow` convention). +Negative `d` above `tol` cause `d^p` to error for fractional `p` (e.g. +`p = 1//2`) and pass through for integer `p`, so the operation itself +enforces the PSD precondition per-power. + +This is the single hook for diagonal-like types: extending `pow_safe` to +a new type (e.g. graded/block diagonal) automatically extends +[`sqrt_safe`](@ref) and [`invsqrt_safe`](@ref). """ -pinv_tol(D::Diagonal, pinv::NamedTuple) = pinv_tol(D; pinv...) -function pinv_tol( - D::Diagonal; atol = zero(real(eltype(D))), - rtol = iszero(atol) ? eps(real(eltype(D))) * size(D, 1) : +function pow_safe( + D::Diagonal, p; + atol = zero(real(eltype(D))), + rtol = iszero(atol) ? eps(real(eltype(D)))^(3 // 4) : zero(real(eltype(D))) ) - return max(atol, rtol * maximum(abs, D.diag; init = zero(real(eltype(D))))) + σ = D.diag + tol = max(atol, rtol * maximum(abs, σ; init = zero(real(eltype(D))))) + return Diagonal(map(d -> abs(d) < tol ? zero(d) : real(d)^p, σ)) end """ - sqrt_safe(D::AbstractMatrix, tol=MatrixAlgebraKit.defaulttol(D)) + sqrt_safe(D; atol=0, rtol=...) -> D^(1//2) -Compute `sqrt(D)`, clamping diagonal entries with `abs < tol` to zero. +Square root of an approximately Hermitian positive semi-definite matrix +`D`. Equivalent to `pow_safe(D, 1//2; kwargs...)`. """ -function sqrt_safe(D::Diagonal, tol = MAK.defaulttol(D)) - return Diagonal(map(d -> abs(d) < tol ? zero(d) : sqrt(d), D.diag)) -end +sqrt_safe(D; kwargs...) = pow_safe(D, 1 // 2; kwargs...) """ - inv_safe(D::AbstractMatrix) + invsqrt_safe(D; atol=0, rtol=...) -> D^(-1//2) -Invert each nonzero diagonal entry of `D`, leaving exact-zero entries -unchanged. Used to form `pinv(sqrt_safe(D, tol))` without re-thresholding. +Inverse square root of an approximately Hermitian positive semi-definite +matrix `D`, treating eigenvalues below tolerance as zero (Moore-Penrose +convention). Equivalent to `pow_safe(D, -1//2; kwargs...)`. """ -function inv_safe(D::Diagonal) - return Diagonal(map(d -> iszero(d) ? d : inv(d), D.diag)) -end +invsqrt_safe(D; kwargs...) = pow_safe(D, -1 // 2; kwargs...) for (gram, gram_with_pinv, eigh_full) in ( (:gram_eigh_full, :gram_eigh_full_with_pinv, :eigh_full), (:gram_eigh_full!!, :gram_eigh_full_with_pinv!!, :eigh_full!), ) @eval begin - function $gram(A::AbstractMatrix; alg = nothing, pinv = (;)) + function $gram(A::AbstractMatrix; alg = nothing, kwargs...) D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) - return sqrt_safe(D, pinv_tol(D, pinv)) * V' + return sqrt_safe(D; kwargs...) * V' end - function $gram_with_pinv(A::AbstractMatrix; alg = nothing, pinv = (;)) + function $gram_with_pinv(A::AbstractMatrix; alg = nothing, kwargs...) D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) - sqrtD = sqrt_safe(D, pinv_tol(D, pinv)) - return sqrtD * V', V * inv_safe(sqrtD) + return sqrt_safe(D; kwargs...) * V', V * invsqrt_safe(D; kwargs...) end end end """ - gram_eigh_full(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X - gram_eigh_full!!(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X + gram_eigh_full(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X + gram_eigh_full!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X Gram factorization of a Hermitian positive semi-definite matrix via its -eigendecomposition: returns `X = Diagonal(sqrt.(Λ)) * V'` such that -`A ≈ X' * X`, where `A = V * Diagonal(Λ) * V'`. Eigenvalues below -[`pinv_tol`](@ref) are clamped to zero. The `!!` variant may destroy `A`. +eigendecomposition: returns `X = sqrt_safe(D) * V'` such that +`A ≈ X' * X`, where `A = V * D * V'`. Eigenvalues below `tol` (see +[`sqrt_safe`](@ref)) are clamped to zero. The `!!` variant may destroy `A`. ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - - `pinv::NamedTuple`: forwarded to [`pinv_tol`](@ref) (e.g. `(; atol, rtol)`). + - `atol`, `rtol`: forwarded to [`sqrt_safe`](@ref). See also [`gram_eigh_full_with_pinv`](@ref). """ gram_eigh_full, gram_eigh_full!! """ - gram_eigh_full_with_pinv(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X, Y - gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, pinv=(;)) -> X, Y + gram_eigh_full_with_pinv(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X, Y + gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X, Y Like [`gram_eigh_full`](@ref), but additionally returns -`Y = V * Diagonal(inv.(sqrt.(Λ))) ≈ pinv(X)` so that `X * Y ≈ I` on the -rank subspace. Eigenvalues below [`pinv_tol`](@ref) are clamped to zero -in both factors. The `!!` variant may destroy `A`. +`Y = V * invsqrt_safe(D) ≈ pinv(X)` so that `X * Y ≈ I` on the rank +subspace. Eigenvalues below `tol` are clamped to zero in both factors. +The `!!` variant may destroy `A`. """ gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!! diff --git a/src/factorizations.jl b/src/factorizations.jl index 48213bff..0a3425f7 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -448,8 +448,8 @@ dimensions. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - - `pinv::NamedTuple`: tolerance options used to clamp small eigenvalues to - zero (see `MatrixAlgebra.pinv_tol`). + - `atol`, `rtol`: tolerance options used to clamp small eigenvalues to + zero (see `MatrixAlgebra.sqrt_safe`). See also [`gram_eigh_full_with_pinv`](@ref) and `MatrixAlgebra.gram_eigh_full`. @@ -491,8 +491,8 @@ that `X * Y ≈ I` on the rank subspace. ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - - `pinv::NamedTuple`: tolerance options used to clamp small eigenvalues to - zero in both `X` and `Y` (see `MatrixAlgebra.pinv_tol`). + - `atol`, `rtol`: tolerance options used to clamp small eigenvalues to + zero in both `X` and `Y` (see `MatrixAlgebra.sqrt_safe`). See also `MatrixAlgebra.gram_eigh_full_with_pinv`. """ diff --git a/test/test_exports.jl b/test/test_exports.jl index b095c9f7..12617616 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -42,14 +42,17 @@ using Test: @test, @testset :gram_eigh_full!!, :gram_eigh_full_with_pinv, :gram_eigh_full_with_pinv!!, + :invsqrt_safe, :lq, :lq!!, :orth, :orth!!, :polar, :polar!!, + :pow_safe, :qr, :qr!!, + :sqrt_safe, :svd, :svd!!, :svdvals, diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 356c3110..03dbeff4 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -368,9 +368,9 @@ end B = randn(T, 4, 2, 3) # k = 4 < codomain dim 6, so A is rank-4 A = contract((:a, :b, :c, :d), conj(B), (:k, :a, :b), B, (:k, :c, :d)) - # Recovery of A is independent of the `pinv` tolerance because all - # nonzero eigenvalues sit far above any reasonable pinv cutoff. - X = gram_eigh_full(A, Val(2); pinv = (; rtol = 1.0e-10)) + # Recovery of A is independent of the `rtol` cutoff because all + # nonzero eigenvalues sit far above any reasonable threshold. + X = gram_eigh_full(A, Val(2); rtol = 1.0e-10) @test A ≈ contract( (:a, :b, :c, :d), conj(X), (:r, :a, :b), X, (:r, :c, :d) ) @@ -378,7 +378,7 @@ end # Moore–Penrose-like identity: X * Y * X ≈ X when Y is pinv(X). With # rank-first X and rank-last Y, contract X[r, a, b] * Y[a, b, s] → P[r, s] # (projector onto the rank subspace), then P * X → X. - X2, Y2 = gram_eigh_full_with_pinv(A, Val(2); pinv = (; rtol = 1.0e-10)) + X2, Y2 = gram_eigh_full_with_pinv(A, Val(2); rtol = 1.0e-10) P = contract((:r, :s), X2, (:r, :a, :b), Y2, (:a, :b, :s)) PX = contract((:r, :c, :d), P, (:r, :s), X2, (:s, :c, :d)) @test PX ≈ X2 diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 179f2ba5..9dfa1b3e 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -313,7 +313,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) Brd = randn(elt, n, k) Ard = Brd * Brd' Xrd, Yrd = MatrixAlgebra.gram_eigh_full_with_pinv( - Ard; pinv = (; rtol = sqrt(eps(real(elt)))) + Ard; rtol = sqrt(eps(real(elt))) ) @test Xrd' * Xrd ≈ Ard P = Xrd * Yrd From 93b98ff6f1772ba9edfe1936ce2ca0c8ce7d07e5 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 11:08:28 -0400 Subject: [PATCH 07/19] Split clamping helpers into diag and Hermitian families Two-tier API: - pow_diag_safe(D::Diagonal, p; atol, rtol) is the leaf operation for diagonal-like types, and the specialization hook for graded or block diagonal types (e.g. GradedArrays). sqrt_diag_safe and invsqrt_diag_safe are thin wrappers. - powh_safe(M, p; atol, rtol) is the user-facing operation for approximately Hermitian positive semi-definite matrices. The AbstractMatrix method calls eigh_full(M) then pow_diag_safe directly on the diagonal result (no recursion through powh_safe). A Diagonal overload forwards to pow_diag_safe. sqrth_safe and invsqrth_safe are thin wrappers around powh_safe. gram_eigh_full and gram_eigh_full_with_pinv now use sqrth_safe and invsqrth_safe (which dispatch to pow_diag_safe on the Diagonal output of eigh_full). This avoids the recursion issue of having a single overloaded name and distinguishes the contract of each family by name. --- src/MatrixAlgebra.jl | 91 +++++++++++++++++++++++++++----------- src/factorizations.jl | 4 +- test/test_exports.jl | 9 ++-- test/test_matrixalgebra.jl | 20 +++++++++ 4 files changed, 93 insertions(+), 31 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index be2db9a2..d630ec3c 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -10,17 +10,20 @@ export eigen, gram_eigh_full!!, gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!!, - invsqrt_safe, + invsqrt_diag_safe, + invsqrth_safe, lq, lq!!, orth, orth!!, polar, polar!!, - pow_safe, + pow_diag_safe, + powh_safe, qr, qr!!, - sqrt_safe, + sqrt_diag_safe, + sqrth_safe, svd, svd!!, svdvals, @@ -82,22 +85,22 @@ for (eigvals, eigh_vals, eig_vals) in end """ - pow_safe(D, p; atol=0, rtol=...) -> D^p + pow_diag_safe(D::Diagonal, p; atol=0, rtol=...) -> D^p -Raise an approximately Hermitian positive semi-definite matrix `D` to -the power `p`. Diagonal entries `d` with `abs(d) < tol` are clamped to -zero before exponentiation, where -`tol = max(atol, rtol * maximum(abs, diagview(D)))`. Default +Raise a diagonal matrix `D` to the power `p`. Diagonal entries `d` with +`abs(d) < tol` are clamped to zero before exponentiation, where +`tol = max(atol, rtol * maximum(abs, D.diag))`. Default `rtol = eps^(3//4)` (matching PEPSKit's `sdiag_pow` convention). Negative `d` above `tol` cause `d^p` to error for fractional `p` (e.g. `p = 1//2`) and pass through for integer `p`, so the operation itself enforces the PSD precondition per-power. -This is the single hook for diagonal-like types: extending `pow_safe` to -a new type (e.g. graded/block diagonal) automatically extends -[`sqrt_safe`](@ref) and [`invsqrt_safe`](@ref). +This is the leaf operation for diagonal-like types: extending it to a +new diagonal-like type (e.g. graded or block diagonal) automatically +extends [`sqrt_diag_safe`](@ref), [`invsqrt_diag_safe`](@ref), and the +[`powh_safe`](@ref) family. """ -function pow_safe( +function pow_diag_safe( D::Diagonal, p; atol = zero(real(eltype(D))), rtol = iszero(atol) ? eps(real(eltype(D)))^(3 // 4) : @@ -109,21 +112,56 @@ function pow_safe( end """ - sqrt_safe(D; atol=0, rtol=...) -> D^(1//2) + sqrt_diag_safe(D; atol=0, rtol=...) -> D^(1//2) -Square root of an approximately Hermitian positive semi-definite matrix -`D`. Equivalent to `pow_safe(D, 1//2; kwargs...)`. +Square root of a diagonal matrix `D`, equivalent to +`pow_diag_safe(D, 1//2; kwargs...)`. """ -sqrt_safe(D; kwargs...) = pow_safe(D, 1 // 2; kwargs...) +sqrt_diag_safe(D; kwargs...) = pow_diag_safe(D, 1 // 2; kwargs...) """ - invsqrt_safe(D; atol=0, rtol=...) -> D^(-1//2) + invsqrt_diag_safe(D; atol=0, rtol=...) -> D^(-1//2) + +Inverse square root of a diagonal matrix `D`, treating diagonal entries +below tolerance as zero (Moore-Penrose convention). Equivalent to +`pow_diag_safe(D, -1//2; kwargs...)`. +""" +invsqrt_diag_safe(D; kwargs...) = pow_diag_safe(D, -1 // 2; kwargs...) + +""" + powh_safe(M::AbstractMatrix, p; alg=nothing, atol=0, rtol=...) -> M^p + powh_safe(D::Diagonal, p; atol=0, rtol=...) -> D^p + +Raise an approximately Hermitian positive semi-definite matrix to the +power `p`. For a general `M`, this is computed via the eigendecomposition +`M = V * D * V'` as `V * powh_safe(D, p) * V'`. For a `Diagonal` input, +this dispatches to [`pow_diag_safe`](@ref). + +See [`pow_diag_safe`](@ref) for tolerance semantics and the +specialization hook. +""" +powh_safe(D::Diagonal, p; kwargs...) = pow_diag_safe(D, p; kwargs...) + +function powh_safe(M::AbstractMatrix, p; alg = nothing, kwargs...) + D, V = MAK.eigh_full(M, MAK.select_algorithm(MAK.eigh_full, M, alg)) + return V * pow_diag_safe(D, p; kwargs...) * V' +end + +""" + sqrth_safe(M; alg=nothing, atol=0, rtol=...) -> M^(1//2) + +Square root of an approximately Hermitian positive semi-definite matrix. +Equivalent to `powh_safe(M, 1//2; kwargs...)`. +""" +sqrth_safe(M; kwargs...) = powh_safe(M, 1 // 2; kwargs...) + +""" + invsqrth_safe(M; alg=nothing, atol=0, rtol=...) -> M^(-1//2) Inverse square root of an approximately Hermitian positive semi-definite -matrix `D`, treating eigenvalues below tolerance as zero (Moore-Penrose -convention). Equivalent to `pow_safe(D, -1//2; kwargs...)`. +matrix. Equivalent to `powh_safe(M, -1//2; kwargs...)`. """ -invsqrt_safe(D; kwargs...) = pow_safe(D, -1 // 2; kwargs...) +invsqrth_safe(M; kwargs...) = powh_safe(M, -1 // 2; kwargs...) for (gram, gram_with_pinv, eigh_full) in ( (:gram_eigh_full, :gram_eigh_full_with_pinv, :eigh_full), @@ -132,11 +170,11 @@ for (gram, gram_with_pinv, eigh_full) in ( @eval begin function $gram(A::AbstractMatrix; alg = nothing, kwargs...) D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) - return sqrt_safe(D; kwargs...) * V' + return sqrth_safe(D; kwargs...) * V' end function $gram_with_pinv(A::AbstractMatrix; alg = nothing, kwargs...) D, V = MAK.$eigh_full(A, MAK.select_algorithm(MAK.$eigh_full, A, alg)) - return sqrt_safe(D; kwargs...) * V', V * invsqrt_safe(D; kwargs...) + return sqrth_safe(D; kwargs...) * V', V * invsqrth_safe(D; kwargs...) end end end @@ -146,14 +184,15 @@ end gram_eigh_full!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X Gram factorization of a Hermitian positive semi-definite matrix via its -eigendecomposition: returns `X = sqrt_safe(D) * V'` such that +eigendecomposition: returns `X = sqrth_safe(D) * V'` such that `A ≈ X' * X`, where `A = V * D * V'`. Eigenvalues below `tol` (see -[`sqrt_safe`](@ref)) are clamped to zero. The `!!` variant may destroy `A`. +[`pow_diag_safe`](@ref)) are clamped to zero. The `!!` variant may +destroy `A`. ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - - `atol`, `rtol`: forwarded to [`sqrt_safe`](@ref). + - `atol`, `rtol`: forwarded to [`pow_diag_safe`](@ref). See also [`gram_eigh_full_with_pinv`](@ref). """ @@ -164,7 +203,7 @@ gram_eigh_full, gram_eigh_full!! gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X, Y Like [`gram_eigh_full`](@ref), but additionally returns -`Y = V * invsqrt_safe(D) ≈ pinv(X)` so that `X * Y ≈ I` on the rank +`Y = V * invsqrth_safe(D) ≈ pinv(X)` so that `X * Y ≈ I` on the rank subspace. Eigenvalues below `tol` are clamped to zero in both factors. The `!!` variant may destroy `A`. """ diff --git a/src/factorizations.jl b/src/factorizations.jl index 0a3425f7..a70091c6 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -449,7 +449,7 @@ dimensions. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - `atol`, `rtol`: tolerance options used to clamp small eigenvalues to - zero (see `MatrixAlgebra.sqrt_safe`). + zero (see `MatrixAlgebra.pow_diag_safe`). See also [`gram_eigh_full_with_pinv`](@ref) and `MatrixAlgebra.gram_eigh_full`. @@ -492,7 +492,7 @@ that `X * Y ≈ I` on the rank subspace. - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - `atol`, `rtol`: tolerance options used to clamp small eigenvalues to - zero in both `X` and `Y` (see `MatrixAlgebra.sqrt_safe`). + zero in both `X` and `Y` (see `MatrixAlgebra.pow_diag_safe`). See also `MatrixAlgebra.gram_eigh_full_with_pinv`. """ diff --git a/test/test_exports.jl b/test/test_exports.jl index 12617616..5968127f 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -42,17 +42,20 @@ using Test: @test, @testset :gram_eigh_full!!, :gram_eigh_full_with_pinv, :gram_eigh_full_with_pinv!!, - :invsqrt_safe, + :invsqrt_diag_safe, + :invsqrth_safe, :lq, :lq!!, :orth, :orth!!, :polar, :polar!!, - :pow_safe, + :pow_diag_safe, + :powh_safe, :qr, :qr!!, - :sqrt_safe, + :sqrt_diag_safe, + :sqrth_safe, :svd, :svd!!, :svdvals, diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 9dfa1b3e..a6c869bb 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -320,4 +320,24 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test P * P ≈ P @test P * Xrd ≈ Xrd end + + @testset "powh_safe / sqrth_safe / invsqrth_safe" begin + n = 4 + B = randn(elt, n, n) + A = B * B' + sqrtA = MatrixAlgebra.sqrth_safe(A) + @test sqrtA * sqrtA ≈ A + @test sqrtA ≈ sqrtA' + + invsqrtA = MatrixAlgebra.invsqrth_safe(A) + @test invsqrtA * sqrtA ≈ I(n) + + # Integer power: passes through without clamping affecting result. + @test MatrixAlgebra.powh_safe(A, 2) ≈ A * A + + # Diagonal-input dispatch path. + D = Diagonal(rand(real(elt), n)) + @test MatrixAlgebra.sqrth_safe(D) ≈ + MatrixAlgebra.pow_diag_safe(D, 1 // 2) + end end From 1b6335f93bf8b2aef231c1e11805a4f5da0ac5d3 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 11:27:53 -0400 Subject: [PATCH 08/19] Use copy instead of deepcopy in test_factorizations.jl The tests only build plain Float64/ComplexF64 arrays via randn, so deepcopy was equivalent to copy. Simplify to the cheaper primitive. --- test/test_factorizations.jl | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 03dbeff4..2a353aca 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -16,7 +16,7 @@ elts = (Float64, ComplexF64) labels_Q = (:b, :a) labels_R = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full = true) @test A == Acopy # should not have altered initial array A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) @@ -36,7 +36,7 @@ end labels_Q = (:b, :a) labels_R = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) Q, R = @constinferred qr(A, labels_A, labels_Q, labels_R; full = false) @test A == Acopy # should not have altered initial array A′ = contract(labels_A, Q, (labels_Q..., :q), R, (:q, labels_R...)) @@ -52,7 +52,7 @@ end labels_Q = (:d, :c) labels_L = (:b, :a) - Acopy = deepcopy(A) + Acopy = copy(A) L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full = true) @test A == Acopy # should not have altered initial array A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) @@ -69,7 +69,7 @@ end labels_Q = (:d, :c) labels_L = (:b, :a) - Acopy = deepcopy(A) + Acopy = copy(A) L, Q = @constinferred lq(A, labels_A, labels_L, labels_Q; full = false) @test A == Acopy # should not have altered initial array A′ = contract(labels_A, L, (labels_L..., :q), Q, (:q, labels_Q...)) @@ -85,7 +85,7 @@ end labels_V = (:b, :a) labels_V′ = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) # type-unstable because of `ishermitian` difference D, V = eigen(A, labels_A, labels_V, labels_V′; ishermitian = false) @test A == Acopy # should not have altered initial array @@ -108,7 +108,7 @@ end labels_V = (:b, :a) labels_V′ = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) # type-unstable because of `ishermitian` difference D, V = eigen(A, labels_A, labels_V, labels_V′; ishermitian = true) @test A == Acopy # should not have altered initial array @@ -133,7 +133,7 @@ end labels_U = (:b, :a) labels_Vᴴ = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = Val(true)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) @@ -167,7 +167,7 @@ end labels_U = (:b, :a) labels_Vᴴ = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) U, S, Vᴴ = @constinferred svd(A, labels_A, labels_U, labels_Vᴴ; full = Val(false)) @test A == Acopy # should not have altered initial array US, labels_US = contract(U, (labels_U..., :u), S, (:u, :v)) @@ -201,7 +201,7 @@ end labels_Vᴴ = (:d, :c) # test truncated SVD - Acopy = deepcopy(A) + Acopy = copy(A) _, S_untrunc, _ = svd(A, labels_A, labels_U, labels_Vᴴ) trunc = truncrank(size(S_untrunc, 1) - 1) @@ -220,7 +220,7 @@ end labels_codomain = (:b, :a) labels_domain = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) N = @constinferred left_null(A, labels_A, labels_codomain, labels_domain) @test A == Acopy # should not have altered initial array # N^ba_n' * A^ba_dc = 0 @@ -244,7 +244,7 @@ end labels_W = (:b, :a) labels_P = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for (W, P) in ( left_polar(A, labels_A, labels_W, labels_P), polar(A, labels_A, labels_W, labels_P; side = :left), @@ -263,7 +263,7 @@ end labels_P = (:b, :a) labels_W = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for (P, W) in ( right_polar(A, labels_A, labels_P, labels_W), polar(A, labels_A, labels_P, labels_W; side = :right), @@ -281,7 +281,7 @@ end labels_W = (:b, :a) labels_P = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for (W, P) in ( left_orth(A, labels_A, labels_W, labels_P), orth(A, labels_A, labels_W, labels_P; side = :left), @@ -300,7 +300,7 @@ end labels_P = (:b, :a) labels_W = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for (P, W) in ( right_orth(A, labels_A, labels_P, labels_W), orth(A, labels_A, labels_P, labels_W; side = :right), @@ -318,7 +318,7 @@ end labels_X = (:b, :a) labels_Y = (:d, :c) - Acopy = deepcopy(A) + Acopy = copy(A) for orth in (:left, :right) X, Y = factorize(A, labels_A, labels_X, labels_Y; orth) @test A == Acopy # should not have altered initial array @@ -345,7 +345,7 @@ end labels_X = (:a, :b) labels_Y = (:c, :d) - Acopy = deepcopy(A) + Acopy = copy(A) X = @constinferred gram_eigh_full(A, labels_A, labels_X, labels_Y) @test A == Acopy # should not have altered initial array A′ = contract(labels_A, conj(X), (:r, :a, :b), X, (:r, :c, :d)) From 13937463a48b11596d616096edafb4312dc12cf0 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 11:34:45 -0400 Subject: [PATCH 09/19] Document atol/rtol in all clamping-family docstrings - Signature lines now show `atol, rtol` with concrete default formulas. - Body text shows `sqrth_safe(D; atol, rtol)` etc. rather than the bare call. - Each docstring's '## Keyword arguments' section explicitly states the defaults for `atol` and `rtol`, sharing a single source of truth via an interpolated `_CLAMP_KWARGS_DOC` constant. --- src/MatrixAlgebra.jl | 98 ++++++++++++++++++++++++++++++------------- src/factorizations.jl | 6 +-- 2 files changed, 70 insertions(+), 34 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index d630ec3c..d7c30cea 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -85,20 +85,30 @@ for (eigvals, eigh_vals, eig_vals) in end """ - pow_diag_safe(D::Diagonal, p; atol=0, rtol=...) -> D^p +Shared documentation for the `atol` and `rtol` keyword arguments of the +`pow_diag_safe` / `powh_safe` family. +""" +const _CLAMP_KWARGS_DOC = """ - `atol::Real`: absolute clamping threshold. Default `0`. +- `rtol::Real`: relative clamping threshold. Default `eps(real(eltype(D)))^(3//4)` when `atol = 0`, else `0`.""" + +""" + pow_diag_safe(D::Diagonal, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p Raise a diagonal matrix `D` to the power `p`. Diagonal entries `d` with `abs(d) < tol` are clamped to zero before exponentiation, where -`tol = max(atol, rtol * maximum(abs, D.diag))`. Default -`rtol = eps^(3//4)` (matching PEPSKit's `sdiag_pow` convention). -Negative `d` above `tol` cause `d^p` to error for fractional `p` (e.g. -`p = 1//2`) and pass through for integer `p`, so the operation itself -enforces the PSD precondition per-power. +`tol = max(atol, rtol * maximum(abs, D.diag))`. Negative `d` above `tol` +cause `d^p` to error for fractional `p` (e.g. `p = 1//2`) and pass +through for integer `p`, so the operation itself enforces the PSD +precondition per-power. This is the leaf operation for diagonal-like types: extending it to a new diagonal-like type (e.g. graded or block diagonal) automatically extends [`sqrt_diag_safe`](@ref), [`invsqrt_diag_safe`](@ref), and the [`powh_safe`](@ref) family. + +## Keyword arguments + +$(_CLAMP_KWARGS_DOC) """ function pow_diag_safe( D::Diagonal, p; @@ -112,33 +122,44 @@ function pow_diag_safe( end """ - sqrt_diag_safe(D; atol=0, rtol=...) -> D^(1//2) + sqrt_diag_safe(D; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(1//2) Square root of a diagonal matrix `D`, equivalent to -`pow_diag_safe(D, 1//2; kwargs...)`. +`pow_diag_safe(D, 1//2; atol, rtol)`. + +## Keyword arguments + +$(_CLAMP_KWARGS_DOC) """ sqrt_diag_safe(D; kwargs...) = pow_diag_safe(D, 1 // 2; kwargs...) """ - invsqrt_diag_safe(D; atol=0, rtol=...) -> D^(-1//2) + invsqrt_diag_safe(D; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(-1//2) Inverse square root of a diagonal matrix `D`, treating diagonal entries below tolerance as zero (Moore-Penrose convention). Equivalent to -`pow_diag_safe(D, -1//2; kwargs...)`. +`pow_diag_safe(D, -1//2; atol, rtol)`. + +## Keyword arguments + +$(_CLAMP_KWARGS_DOC) """ invsqrt_diag_safe(D; kwargs...) = pow_diag_safe(D, -1 // 2; kwargs...) """ - powh_safe(M::AbstractMatrix, p; alg=nothing, atol=0, rtol=...) -> M^p - powh_safe(D::Diagonal, p; atol=0, rtol=...) -> D^p + powh_safe(M::AbstractMatrix, p; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^p + powh_safe(D::Diagonal, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p Raise an approximately Hermitian positive semi-definite matrix to the power `p`. For a general `M`, this is computed via the eigendecomposition -`M = V * D * V'` as `V * powh_safe(D, p) * V'`. For a `Diagonal` input, -this dispatches to [`pow_diag_safe`](@ref). +`M = V * D * V'` as `V * pow_diag_safe(D, p; atol, rtol) * V'`. For a +`Diagonal` input, this dispatches to [`pow_diag_safe`](@ref). -See [`pow_diag_safe`](@ref) for tolerance semantics and the -specialization hook. +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when + `M` is non-diagonal). + $(_CLAMP_KWARGS_DOC) """ powh_safe(D::Diagonal, p; kwargs...) = pow_diag_safe(D, p; kwargs...) @@ -148,18 +169,30 @@ function powh_safe(M::AbstractMatrix, p; alg = nothing, kwargs...) end """ - sqrth_safe(M; alg=nothing, atol=0, rtol=...) -> M^(1//2) + sqrth_safe(M; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(1//2) Square root of an approximately Hermitian positive semi-definite matrix. -Equivalent to `powh_safe(M, 1//2; kwargs...)`. +Equivalent to `powh_safe(M, 1//2; alg, atol, rtol)`. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when + `M` is non-diagonal). + $(_CLAMP_KWARGS_DOC) """ sqrth_safe(M; kwargs...) = powh_safe(M, 1 // 2; kwargs...) """ - invsqrth_safe(M; alg=nothing, atol=0, rtol=...) -> M^(-1//2) + invsqrth_safe(M; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(-1//2) Inverse square root of an approximately Hermitian positive semi-definite -matrix. Equivalent to `powh_safe(M, -1//2; kwargs...)`. +matrix. Equivalent to `powh_safe(M, -1//2; alg, atol, rtol)`. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when + `M` is non-diagonal). + $(_CLAMP_KWARGS_DOC) """ invsqrth_safe(M; kwargs...) = powh_safe(M, -1 // 2; kwargs...) @@ -180,32 +213,37 @@ for (gram, gram_with_pinv, eigh_full) in ( end """ - gram_eigh_full(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X - gram_eigh_full!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X + gram_eigh_full(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X + gram_eigh_full!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X Gram factorization of a Hermitian positive semi-definite matrix via its -eigendecomposition: returns `X = sqrth_safe(D) * V'` such that -`A ≈ X' * X`, where `A = V * D * V'`. Eigenvalues below `tol` (see +eigendecomposition: returns `X = sqrth_safe(D; atol, rtol) * V'` such +that `A ≈ X' * X`, where `A = V * D * V'`. Eigenvalues below `tol` (see [`pow_diag_safe`](@ref)) are clamped to zero. The `!!` variant may destroy `A`. ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - - `atol`, `rtol`: forwarded to [`pow_diag_safe`](@ref). + $(_CLAMP_KWARGS_DOC) See also [`gram_eigh_full_with_pinv`](@ref). """ gram_eigh_full, gram_eigh_full!! """ - gram_eigh_full_with_pinv(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X, Y - gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=...) -> X, Y + gram_eigh_full_with_pinv(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y + gram_eigh_full_with_pinv!!(A::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(A)))^(3//4)) -> X, Y Like [`gram_eigh_full`](@ref), but additionally returns -`Y = V * invsqrth_safe(D) ≈ pinv(X)` so that `X * Y ≈ I` on the rank -subspace. Eigenvalues below `tol` are clamped to zero in both factors. -The `!!` variant may destroy `A`. +`Y = V * invsqrth_safe(D; atol, rtol) ≈ pinv(X)` so that `X * Y ≈ I` on +the rank subspace. Eigenvalues below `tol` are clamped to zero in both +factors. The `!!` variant may destroy `A`. + +## Keyword arguments + + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. + $(_CLAMP_KWARGS_DOC) """ gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!! diff --git a/src/factorizations.jl b/src/factorizations.jl index a70091c6..8888fb72 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -448,8 +448,7 @@ dimensions. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - - `atol`, `rtol`: tolerance options used to clamp small eigenvalues to - zero (see `MatrixAlgebra.pow_diag_safe`). + $(MatrixAlgebra._CLAMP_KWARGS_DOC) See also [`gram_eigh_full_with_pinv`](@ref) and `MatrixAlgebra.gram_eigh_full`. @@ -491,8 +490,7 @@ that `X * Y ≈ I` on the rank subspace. ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - - `atol`, `rtol`: tolerance options used to clamp small eigenvalues to - zero in both `X` and `Y` (see `MatrixAlgebra.pow_diag_safe`). + $(MatrixAlgebra._CLAMP_KWARGS_DOC) See also `MatrixAlgebra.gram_eigh_full_with_pinv`. """ From 44b05d2c39d732293782416f73917ebf4af7fc5a Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 11:38:23 -0400 Subject: [PATCH 10/19] Fix interpolated docstring rendering for clamping kwargs The formatter was reindenting the const definition's continuation line and indenting `$(_CLAMP_KWARGS_DOC)` as a list continuation after the `alg` bullet, which made atol appear as a sub-bullet of alg and rtol disconnected. Define the const via `join` of a tuple of single-line strings so the formatter cannot mangle line breaks, and insert a blank line before `$(_CLAMP_KWARGS_DOC)` in docstrings so it is not treated as a list continuation. --- src/MatrixAlgebra.jl | 23 ++++++++++++++++------- src/factorizations.jl | 6 ++++-- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index d7c30cea..28ca40d9 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -88,8 +88,12 @@ end Shared documentation for the `atol` and `rtol` keyword arguments of the `pow_diag_safe` / `powh_safe` family. """ -const _CLAMP_KWARGS_DOC = """ - `atol::Real`: absolute clamping threshold. Default `0`. -- `rtol::Real`: relative clamping threshold. Default `eps(real(eltype(D)))^(3//4)` when `atol = 0`, else `0`.""" +const _CLAMP_KWARGS_DOC = join( + ( + " - `atol::Real`: absolute clamping threshold. Default `0`.", + " - `rtol::Real`: relative clamping threshold. Default `eps(real(eltype(D)))^(3//4)` when `atol = 0`, else `0`.", + ), "\n" +) """ pow_diag_safe(D::Diagonal, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p @@ -159,7 +163,8 @@ power `p`. For a general `M`, this is computed via the eigendecomposition - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when `M` is non-diagonal). - $(_CLAMP_KWARGS_DOC) + +$(_CLAMP_KWARGS_DOC) """ powh_safe(D::Diagonal, p; kwargs...) = pow_diag_safe(D, p; kwargs...) @@ -178,7 +183,8 @@ Equivalent to `powh_safe(M, 1//2; alg, atol, rtol)`. - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when `M` is non-diagonal). - $(_CLAMP_KWARGS_DOC) + +$(_CLAMP_KWARGS_DOC) """ sqrth_safe(M; kwargs...) = powh_safe(M, 1 // 2; kwargs...) @@ -192,7 +198,8 @@ matrix. Equivalent to `powh_safe(M, -1//2; alg, atol, rtol)`. - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when `M` is non-diagonal). - $(_CLAMP_KWARGS_DOC) + +$(_CLAMP_KWARGS_DOC) """ invsqrth_safe(M; kwargs...) = powh_safe(M, -1 // 2; kwargs...) @@ -225,7 +232,8 @@ destroy `A`. ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - $(_CLAMP_KWARGS_DOC) + +$(_CLAMP_KWARGS_DOC) See also [`gram_eigh_full_with_pinv`](@ref). """ @@ -243,7 +251,8 @@ factors. The `!!` variant may destroy `A`. ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - $(_CLAMP_KWARGS_DOC) + +$(_CLAMP_KWARGS_DOC) """ gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!! diff --git a/src/factorizations.jl b/src/factorizations.jl index 8888fb72..248296c3 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -448,7 +448,8 @@ dimensions. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - $(MatrixAlgebra._CLAMP_KWARGS_DOC) + +$(MatrixAlgebra._CLAMP_KWARGS_DOC) See also [`gram_eigh_full_with_pinv`](@ref) and `MatrixAlgebra.gram_eigh_full`. @@ -490,7 +491,8 @@ that `X * Y ≈ I` on the rank subspace. ## Keyword arguments - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. - $(MatrixAlgebra._CLAMP_KWARGS_DOC) + +$(MatrixAlgebra._CLAMP_KWARGS_DOC) See also `MatrixAlgebra.gram_eigh_full_with_pinv`. """ From 27acfc7609699ef58be1e9f9602fc280153abb92 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 11:42:26 -0400 Subject: [PATCH 11/19] Make clamping-kwargs doc a function over the variable name The default rtol formula reads against a specific variable in each docstring (D for pow_diag_safe, M for powh_safe, A for gram_eigh_full). Convert the shared _CLAMP_KWARGS_DOC constant into a _clamp_kwargs_doc function that takes the variable name as an argument, so each docstring gets the right name interpolated. --- src/MatrixAlgebra.jl | 36 +++++++++++++++++++++--------------- src/factorizations.jl | 4 ++-- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 28ca40d9..7dcae14d 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -85,15 +85,21 @@ for (eigvals, eigh_vals, eig_vals) in end """ + _clamp_kwargs_doc(arg::AbstractString) + Shared documentation for the `atol` and `rtol` keyword arguments of the -`pow_diag_safe` / `powh_safe` family. +`pow_diag_safe` / `powh_safe` family. `arg` is the name of the matrix +argument used in the signatures of the host docstring, so the default +`rtol` formula reads against the right variable. """ -const _CLAMP_KWARGS_DOC = join( - ( - " - `atol::Real`: absolute clamping threshold. Default `0`.", - " - `rtol::Real`: relative clamping threshold. Default `eps(real(eltype(D)))^(3//4)` when `atol = 0`, else `0`.", - ), "\n" -) +function _clamp_kwargs_doc(arg::AbstractString) + return join( + ( + " - `atol::Real`: absolute clamping threshold. Default `0`.", + " - `rtol::Real`: relative clamping threshold. Default `eps(real(eltype($arg)))^(3//4)` when `atol = 0`, else `0`.", + ), "\n" + ) +end """ pow_diag_safe(D::Diagonal, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p @@ -112,7 +118,7 @@ extends [`sqrt_diag_safe`](@ref), [`invsqrt_diag_safe`](@ref), and the ## Keyword arguments -$(_CLAMP_KWARGS_DOC) +$(_clamp_kwargs_doc("D")) """ function pow_diag_safe( D::Diagonal, p; @@ -133,7 +139,7 @@ Square root of a diagonal matrix `D`, equivalent to ## Keyword arguments -$(_CLAMP_KWARGS_DOC) +$(_clamp_kwargs_doc("D")) """ sqrt_diag_safe(D; kwargs...) = pow_diag_safe(D, 1 // 2; kwargs...) @@ -146,7 +152,7 @@ below tolerance as zero (Moore-Penrose convention). Equivalent to ## Keyword arguments -$(_CLAMP_KWARGS_DOC) +$(_clamp_kwargs_doc("D")) """ invsqrt_diag_safe(D; kwargs...) = pow_diag_safe(D, -1 // 2; kwargs...) @@ -164,7 +170,7 @@ power `p`. For a general `M`, this is computed via the eigendecomposition - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when `M` is non-diagonal). -$(_CLAMP_KWARGS_DOC) +$(_clamp_kwargs_doc("M")) """ powh_safe(D::Diagonal, p; kwargs...) = pow_diag_safe(D, p; kwargs...) @@ -184,7 +190,7 @@ Equivalent to `powh_safe(M, 1//2; alg, atol, rtol)`. - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when `M` is non-diagonal). -$(_CLAMP_KWARGS_DOC) +$(_clamp_kwargs_doc("M")) """ sqrth_safe(M; kwargs...) = powh_safe(M, 1 // 2; kwargs...) @@ -199,7 +205,7 @@ matrix. Equivalent to `powh_safe(M, -1//2; alg, atol, rtol)`. - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when `M` is non-diagonal). -$(_CLAMP_KWARGS_DOC) +$(_clamp_kwargs_doc("M")) """ invsqrth_safe(M; kwargs...) = powh_safe(M, -1 // 2; kwargs...) @@ -233,7 +239,7 @@ destroy `A`. - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. -$(_CLAMP_KWARGS_DOC) +$(_clamp_kwargs_doc("A")) See also [`gram_eigh_full_with_pinv`](@ref). """ @@ -252,7 +258,7 @@ factors. The `!!` variant may destroy `A`. - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. -$(_CLAMP_KWARGS_DOC) +$(_clamp_kwargs_doc("A")) """ gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!! diff --git a/src/factorizations.jl b/src/factorizations.jl index 248296c3..cf48fbb4 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -449,7 +449,7 @@ dimensions. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. -$(MatrixAlgebra._CLAMP_KWARGS_DOC) +$(MatrixAlgebra._clamp_kwargs_doc("A")) See also [`gram_eigh_full_with_pinv`](@ref) and `MatrixAlgebra.gram_eigh_full`. @@ -492,7 +492,7 @@ that `X * Y ≈ I` on the rank subspace. - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. -$(MatrixAlgebra._CLAMP_KWARGS_DOC) +$(MatrixAlgebra._clamp_kwargs_doc("A")) See also `MatrixAlgebra.gram_eigh_full_with_pinv`. """ From 85be5f84048aad6169186049a8c2b2864b10e222 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 12:02:41 -0400 Subject: [PATCH 12/19] Add jldoctest examples to gram_eigh_full entries Cover the two user-facing entry points at both matrix and tensor layer with executable examples. Verified via Documenter.doctest. --- src/MatrixAlgebra.jl | 33 +++++++++++++++++++++++++++++++++ src/factorizations.jl | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 7dcae14d..2d70161e 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -241,6 +241,20 @@ destroy `A`. $(_clamp_kwargs_doc("A")) +# Examples + +```jldoctest +julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full + +julia> B = [1.0 0.5; 0.5 2.0]; + A = B * B'; + +julia> X = gram_eigh_full(A); + +julia> X' * X ≈ A +true +``` + See also [`gram_eigh_full_with_pinv`](@ref). """ gram_eigh_full, gram_eigh_full!! @@ -259,6 +273,25 @@ factors. The `!!` variant may destroy `A`. - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. $(_clamp_kwargs_doc("A")) + +# Examples + +```jldoctest +julia> using LinearAlgebra: I + +julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full_with_pinv + +julia> B = [1.0 0.5; 0.5 2.0]; + A = B * B'; + +julia> X, Y = gram_eigh_full_with_pinv(A); + +julia> X' * X ≈ A +true + +julia> X * Y ≈ I +true +``` """ gram_eigh_full_with_pinv, gram_eigh_full_with_pinv!! diff --git a/src/factorizations.jl b/src/factorizations.jl index cf48fbb4..48c084a2 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -451,6 +451,20 @@ dimensions. Returns `X` such that `A ≈ X' * X` (contracted on the rank leg). $(MatrixAlgebra._clamp_kwargs_doc("A")) +# Examples + +```jldoctest +julia> using TensorAlgebra: contract, gram_eigh_full + +julia> B = randn(3, 2, 2); + A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d)); + +julia> X = gram_eigh_full(A, (:a, :b, :c, :d), (:a, :b), (:c, :d)); + +julia> A ≈ contract((:a, :b, :c, :d), conj(X), (:r, :a, :b), X, (:r, :c, :d)) +true +``` + See also [`gram_eigh_full_with_pinv`](@ref) and `MatrixAlgebra.gram_eigh_full`. """ @@ -494,6 +508,25 @@ that `X * Y ≈ I` on the rank subspace. $(MatrixAlgebra._clamp_kwargs_doc("A")) +# Examples + +```jldoctest +julia> using LinearAlgebra: I + +julia> using TensorAlgebra: contract, gram_eigh_full_with_pinv + +julia> B = randn(8, 2, 2); + A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d)); + +julia> X, Y = gram_eigh_full_with_pinv(A, (:a, :b, :c, :d), (:a, :b), (:c, :d)); + +julia> A ≈ contract((:a, :b, :c, :d), conj(X), (:r, :a, :b), X, (:r, :c, :d)) +true + +julia> contract((:r, :s), X, (:r, :a, :b), Y, (:a, :b, :s)) ≈ I +true +``` + See also `MatrixAlgebra.gram_eigh_full_with_pinv`. """ gram_eigh_full_with_pinv From 270e37ec45f2dc6b7f3d894e0abe2fe9b42f6118 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 15:14:00 -0400 Subject: [PATCH 13/19] Split joined julia> blocks in gram doctest examples Each setup statement gets its own julia> prompt for readability. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/MatrixAlgebra.jl | 6 ++++-- src/factorizations.jl | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 2d70161e..562d878a 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -247,7 +247,8 @@ $(_clamp_kwargs_doc("A")) julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full julia> B = [1.0 0.5; 0.5 2.0]; - A = B * B'; + +julia> A = B * B'; julia> X = gram_eigh_full(A); @@ -282,7 +283,8 @@ julia> using LinearAlgebra: I julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full_with_pinv julia> B = [1.0 0.5; 0.5 2.0]; - A = B * B'; + +julia> A = B * B'; julia> X, Y = gram_eigh_full_with_pinv(A); diff --git a/src/factorizations.jl b/src/factorizations.jl index 48c084a2..d2cc30af 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -457,7 +457,8 @@ $(MatrixAlgebra._clamp_kwargs_doc("A")) julia> using TensorAlgebra: contract, gram_eigh_full julia> B = randn(3, 2, 2); - A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d)); + +julia> A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d)); julia> X = gram_eigh_full(A, (:a, :b, :c, :d), (:a, :b), (:c, :d)); @@ -516,7 +517,8 @@ julia> using LinearAlgebra: I julia> using TensorAlgebra: contract, gram_eigh_full_with_pinv julia> B = randn(8, 2, 2); - A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d)); + +julia> A = contract((:a, :b, :c, :d), conj(B), (:r, :a, :b), B, (:r, :c, :d)); julia> X, Y = gram_eigh_full_with_pinv(A, (:a, :b, :c, :d), (:a, :b), (:c, :d)); From ceeeaf9fc2ea1f2faabc838b60470557401dab8d Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 15:37:47 -0400 Subject: [PATCH 14/19] Fix docstring attachments and signatures in MatrixAlgebra helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move powh_safe docstring to the generic AbstractMatrix method; the Diagonal specialization (now below the generic) inherits via the function-level docs. - Add explicit ::Diagonal / ::AbstractMatrix type annotations to sqrt_diag_safe, invsqrt_diag_safe, sqrth_safe, invsqrth_safe so the method signatures match the documented signatures. - Stop silently absorbing `alg` on the Diagonal powh_safe path: passing an incompatible `alg` now errors via pow_diag_safe's unrecognized-kwarg, matching MAK's convention of erroring on incompatible algorithms rather than ignoring them. - Switch `A = B * B'` to `A = B' * B` in the matrix-layer jldoctest examples for consistency with the rank-first convention `A ≈ X' * X`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/MatrixAlgebra.jl | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 562d878a..e78c5fa9 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -132,7 +132,7 @@ function pow_diag_safe( end """ - sqrt_diag_safe(D; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(1//2) + sqrt_diag_safe(D::Diagonal; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(1//2) Square root of a diagonal matrix `D`, equivalent to `pow_diag_safe(D, 1//2; atol, rtol)`. @@ -141,10 +141,10 @@ Square root of a diagonal matrix `D`, equivalent to $(_clamp_kwargs_doc("D")) """ -sqrt_diag_safe(D; kwargs...) = pow_diag_safe(D, 1 // 2; kwargs...) +sqrt_diag_safe(D::Diagonal; kwargs...) = pow_diag_safe(D, 1 // 2; kwargs...) """ - invsqrt_diag_safe(D; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(-1//2) + invsqrt_diag_safe(D::Diagonal; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(-1//2) Inverse square root of a diagonal matrix `D`, treating diagonal entries below tolerance as zero (Moore-Penrose convention). Equivalent to @@ -154,60 +154,55 @@ below tolerance as zero (Moore-Penrose convention). Equivalent to $(_clamp_kwargs_doc("D")) """ -invsqrt_diag_safe(D; kwargs...) = pow_diag_safe(D, -1 // 2; kwargs...) +invsqrt_diag_safe(D::Diagonal; kwargs...) = pow_diag_safe(D, -1 // 2; kwargs...) """ powh_safe(M::AbstractMatrix, p; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^p - powh_safe(D::Diagonal, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p Raise an approximately Hermitian positive semi-definite matrix to the -power `p`. For a general `M`, this is computed via the eigendecomposition -`M = V * D * V'` as `V * pow_diag_safe(D, p; atol, rtol) * V'`. For a -`Diagonal` input, this dispatches to [`pow_diag_safe`](@ref). +power `p`. Computed via the eigendecomposition `M = V * D * V'` as +`V * pow_diag_safe(D, p; atol, rtol) * V'`. ## Keyword arguments - - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when - `M` is non-diagonal). + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. $(_clamp_kwargs_doc("M")) """ -powh_safe(D::Diagonal, p; kwargs...) = pow_diag_safe(D, p; kwargs...) - function powh_safe(M::AbstractMatrix, p; alg = nothing, kwargs...) D, V = MAK.eigh_full(M, MAK.select_algorithm(MAK.eigh_full, M, alg)) return V * pow_diag_safe(D, p; kwargs...) * V' end +powh_safe(D::Diagonal, p; kwargs...) = pow_diag_safe(D, p; kwargs...) + """ - sqrth_safe(M; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(1//2) + sqrth_safe(M::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(1//2) Square root of an approximately Hermitian positive semi-definite matrix. Equivalent to `powh_safe(M, 1//2; alg, atol, rtol)`. ## Keyword arguments - - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when - `M` is non-diagonal). + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. $(_clamp_kwargs_doc("M")) """ -sqrth_safe(M; kwargs...) = powh_safe(M, 1 // 2; kwargs...) +sqrth_safe(M::AbstractMatrix; kwargs...) = powh_safe(M, 1 // 2; kwargs...) """ - invsqrth_safe(M; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(-1//2) + invsqrth_safe(M::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(-1//2) Inverse square root of an approximately Hermitian positive semi-definite matrix. Equivalent to `powh_safe(M, -1//2; alg, atol, rtol)`. ## Keyword arguments - - `alg`: forwarded to `MatrixAlgebraKit.eigh_full` (only used when - `M` is non-diagonal). + - `alg`: forwarded to `MatrixAlgebraKit.eigh_full`. $(_clamp_kwargs_doc("M")) """ -invsqrth_safe(M; kwargs...) = powh_safe(M, -1 // 2; kwargs...) +invsqrth_safe(M::AbstractMatrix; kwargs...) = powh_safe(M, -1 // 2; kwargs...) for (gram, gram_with_pinv, eigh_full) in ( (:gram_eigh_full, :gram_eigh_full_with_pinv, :eigh_full), @@ -248,7 +243,7 @@ julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full julia> B = [1.0 0.5; 0.5 2.0]; -julia> A = B * B'; +julia> A = B' * B; julia> X = gram_eigh_full(A); @@ -284,7 +279,7 @@ julia> using TensorAlgebra.MatrixAlgebra: gram_eigh_full_with_pinv julia> B = [1.0 0.5; 0.5 2.0]; -julia> A = B * B'; +julia> A = B' * B; julia> X, Y = gram_eigh_full_with_pinv(A); From 72d5a67a9368e575661a46af08c9729b1f4d5101 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 15:46:09 -0400 Subject: [PATCH 15/19] Remove docstring from internal _clamp_kwargs_doc helper Internal underscore-prefixed helpers shouldn't carry docstrings. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/MatrixAlgebra.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index e78c5fa9..1934f4f2 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -84,14 +84,6 @@ for (eigvals, eigh_vals, eig_vals) in end end -""" - _clamp_kwargs_doc(arg::AbstractString) - -Shared documentation for the `atol` and `rtol` keyword arguments of the -`pow_diag_safe` / `powh_safe` family. `arg` is the name of the matrix -argument used in the signatures of the host docstring, so the default -`rtol` formula reads against the right variable. -""" function _clamp_kwargs_doc(arg::AbstractString) return join( ( From 4bc5fd9b2167b8530a8b823be3b35f10ef92120b Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 16:04:54 -0400 Subject: [PATCH 16/19] Generalize pow_diag_safe to AbstractMatrix via MAK.diagview/diagonal pow_diag_safe, sqrt_diag_safe, invsqrt_diag_safe now take any AbstractMatrix, extracting entries through MAK.diagview and rebuilding through MAK.diagonal. An isdiag(D) guard rejects non-diagonal-structured inputs with a clear ArgumentError. Types extending MAK.diagview and MAK.diagonal (e.g. graded or block-diagonal matrices) now inherit the whole family without writing their own pow_diag_safe method. powh_safe collapses to a single AbstractMatrix method that takes the isdiag(M) fast-path before falling through to eigh_full; the explicit Diagonal specialization is no longer needed. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/MatrixAlgebra.jl | 55 +++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 1934f4f2..45e4c78c 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -30,7 +30,7 @@ export eigen, svdvals!! import MatrixAlgebraKit as MAK -using LinearAlgebra: LinearAlgebra, Diagonal, norm +using LinearAlgebra: LinearAlgebra, Diagonal, isdiag, norm for (f, f_full, f_compact) in ( (:qr, :qr_full, :qr_compact), @@ -94,66 +94,70 @@ function _clamp_kwargs_doc(arg::AbstractString) end """ - pow_diag_safe(D::Diagonal, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p + pow_diag_safe(D::AbstractMatrix, p; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^p -Raise a diagonal matrix `D` to the power `p`. Diagonal entries `d` with -`abs(d) < tol` are clamped to zero before exponentiation, where -`tol = max(atol, rtol * maximum(abs, D.diag))`. Negative `d` above `tol` -cause `d^p` to error for fractional `p` (e.g. `p = 1//2`) and pass -through for integer `p`, so the operation itself enforces the PSD -precondition per-power. +Raise a diagonal-structured matrix `D` to the power `p`. Diagonal entries +`d` of `MAK.diagview(D)` with `abs(d) < tol` are clamped to zero before +exponentiation, where `tol = max(atol, rtol * maximum(abs, diagview(D)))`. +Negative `d` above `tol` cause `d^p` to error for fractional `p` (e.g. +`p = 1//2`) and pass through for integer `p`, so the operation itself +enforces the PSD precondition per-power. Errors if `isdiag(D)` is `false`. -This is the leaf operation for diagonal-like types: extending it to a -new diagonal-like type (e.g. graded or block diagonal) automatically -extends [`sqrt_diag_safe`](@ref), [`invsqrt_diag_safe`](@ref), and the -[`powh_safe`](@ref) family. +The implementation extracts entries via `MAK.diagview` and rebuilds via +`MAK.diagonal`, so types extending those (e.g. graded or block diagonal) +automatically extend [`sqrt_diag_safe`](@ref), [`invsqrt_diag_safe`](@ref), +and the [`powh_safe`](@ref) family. ## Keyword arguments $(_clamp_kwargs_doc("D")) """ function pow_diag_safe( - D::Diagonal, p; + D::AbstractMatrix, p; atol = zero(real(eltype(D))), rtol = iszero(atol) ? eps(real(eltype(D)))^(3 // 4) : zero(real(eltype(D))) ) - σ = D.diag + isdiag(D) || throw( + ArgumentError("pow_diag_safe requires a diagonal-structured matrix") + ) + σ = MAK.diagview(D) tol = max(atol, rtol * maximum(abs, σ; init = zero(real(eltype(D))))) - return Diagonal(map(d -> abs(d) < tol ? zero(d) : real(d)^p, σ)) + return MAK.diagonal(map(d -> abs(d) < tol ? zero(d) : real(d)^p, σ)) end """ - sqrt_diag_safe(D::Diagonal; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(1//2) + sqrt_diag_safe(D::AbstractMatrix; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(1//2) -Square root of a diagonal matrix `D`, equivalent to +Square root of a diagonal-structured matrix `D`, equivalent to `pow_diag_safe(D, 1//2; atol, rtol)`. ## Keyword arguments $(_clamp_kwargs_doc("D")) """ -sqrt_diag_safe(D::Diagonal; kwargs...) = pow_diag_safe(D, 1 // 2; kwargs...) +sqrt_diag_safe(D::AbstractMatrix; kwargs...) = pow_diag_safe(D, 1 // 2; kwargs...) """ - invsqrt_diag_safe(D::Diagonal; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(-1//2) + invsqrt_diag_safe(D::AbstractMatrix; atol=0, rtol=eps(real(eltype(D)))^(3//4)) -> D^(-1//2) -Inverse square root of a diagonal matrix `D`, treating diagonal entries -below tolerance as zero (Moore-Penrose convention). Equivalent to +Inverse square root of a diagonal-structured matrix `D`, treating diagonal +entries below tolerance as zero (Moore-Penrose convention). Equivalent to `pow_diag_safe(D, -1//2; atol, rtol)`. ## Keyword arguments $(_clamp_kwargs_doc("D")) """ -invsqrt_diag_safe(D::Diagonal; kwargs...) = pow_diag_safe(D, -1 // 2; kwargs...) +invsqrt_diag_safe(D::AbstractMatrix; kwargs...) = pow_diag_safe(D, -1 // 2; kwargs...) """ powh_safe(M::AbstractMatrix, p; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^p Raise an approximately Hermitian positive semi-definite matrix to the -power `p`. Computed via the eigendecomposition `M = V * D * V'` as -`V * pow_diag_safe(D, p; atol, rtol) * V'`. +power `p`. For diagonal-structured `M` (`isdiag(M) == true`), dispatches +to [`pow_diag_safe`](@ref) and skips the eigendecomposition. Otherwise, +computes via `M = V * D * V'` as `V * pow_diag_safe(D, p; atol, rtol) * V'`. ## Keyword arguments @@ -162,12 +166,11 @@ power `p`. Computed via the eigendecomposition `M = V * D * V'` as $(_clamp_kwargs_doc("M")) """ function powh_safe(M::AbstractMatrix, p; alg = nothing, kwargs...) + isdiag(M) && return pow_diag_safe(M, p; kwargs...) D, V = MAK.eigh_full(M, MAK.select_algorithm(MAK.eigh_full, M, alg)) return V * pow_diag_safe(D, p; kwargs...) * V' end -powh_safe(D::Diagonal, p; kwargs...) = pow_diag_safe(D, p; kwargs...) - """ sqrth_safe(M::AbstractMatrix; alg=nothing, atol=0, rtol=eps(real(eltype(M)))^(3//4)) -> M^(1//2) From 8ee9aa97c1dc2382ee442389444e44e23f75e32c Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 16:26:01 -0400 Subject: [PATCH 17/19] Cover new pow_diag_safe contract in MatrixAlgebra tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Switch the gram and powh test setup to A = B' * B (matching the rank-first A ≈ X' * X convention used in the doctests). - Add coverage for the new pow_diag_safe contract: - dense diagonal-structured input goes through the isdiag fast-path (result still typed as Diagonal via MAK.diagonal); - direct call on Matrix(Diagonal(...)) matches the Diagonal call; - non-diagonal input raises ArgumentError via the isdiag guard. Co-Authored-By: Claude Opus 4.7 (1M context) --- test/test_matrixalgebra.jl | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index a6c869bb..4c4a21ab 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -293,7 +293,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) n = 5 # Full-rank Hermitian PSD. B = randn(elt, n, n) - A = B * B' + A = B' * B X = MatrixAlgebra.gram_eigh_full(A) @test X' * X ≈ A @test size(X) == (n, n) @@ -310,8 +310,8 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) # X * Y is the projector onto the rank-k subspace (idempotent, # rank k), and P * X ≈ X (Moore–Penrose). k = 3 - Brd = randn(elt, n, k) - Ard = Brd * Brd' + Brd = randn(elt, k, n) + Ard = Brd' * Brd Xrd, Yrd = MatrixAlgebra.gram_eigh_full_with_pinv( Ard; rtol = sqrt(eps(real(elt))) ) @@ -324,7 +324,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "powh_safe / sqrth_safe / invsqrth_safe" begin n = 4 B = randn(elt, n, n) - A = B * B' + A = B' * B sqrtA = MatrixAlgebra.sqrth_safe(A) @test sqrtA * sqrtA ≈ A @test sqrtA ≈ sqrtA' @@ -339,5 +339,20 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) D = Diagonal(rand(real(elt), n)) @test MatrixAlgebra.sqrth_safe(D) ≈ MatrixAlgebra.pow_diag_safe(D, 1 // 2) + + # `isdiag` fast-path: a dense matrix that happens to be diagonal- + # structured goes through `pow_diag_safe`, not `eigh_full`, and + # the result is still a `Diagonal` (from `MAK.diagonal`). + Mdiag = Matrix(D) + sqrtMdiag = MatrixAlgebra.powh_safe(Mdiag, 1 // 2) + @test sqrtMdiag isa Diagonal + @test sqrtMdiag ≈ MatrixAlgebra.pow_diag_safe(D, 1 // 2) + + # `pow_diag_safe` directly on a diagonal-structured `AbstractMatrix`. + @test MatrixAlgebra.pow_diag_safe(Mdiag, 1 // 2) ≈ + MatrixAlgebra.pow_diag_safe(D, 1 // 2) + + # `pow_diag_safe` rejects non-diagonal inputs. + @test_throws ArgumentError MatrixAlgebra.pow_diag_safe(A, 1 // 2) end end From 21e58fde9e6ad57d32643227d6540be3102996aa Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 16:42:40 -0400 Subject: [PATCH 18/19] Cross-link MatrixAlgebra gram entries from the tensor-layer docstrings The tensor-layer gram_eigh_full and gram_eigh_full_with_pinv docstrings now link to their MatrixAlgebra submodule counterparts via proper Documenter ref links instead of bare code spans. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/factorizations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/factorizations.jl b/src/factorizations.jl index d2cc30af..6d5843dc 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -467,7 +467,7 @@ true ``` See also [`gram_eigh_full_with_pinv`](@ref) and -`MatrixAlgebra.gram_eigh_full`. +[`MatrixAlgebra.gram_eigh_full`](@ref). """ gram_eigh_full @@ -529,7 +529,7 @@ julia> contract((:r, :s), X, (:r, :a, :b), Y, (:a, :b, :s)) ≈ I true ``` -See also `MatrixAlgebra.gram_eigh_full_with_pinv`. +See also [`MatrixAlgebra.gram_eigh_full_with_pinv`](@ref). """ gram_eigh_full_with_pinv From 653badfdacafee956df03fe641ea642810a2566c Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Fri, 29 May 2026 17:08:05 -0400 Subject: [PATCH 19/19] Stabilize gram tests against Float32 random rank deficiency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A square random `B` for `A = B' * B` can produce a smallest eigenvalue below the default `eps(Float32)^(3//4)` rtol clamp on some seeds, which makes `gram_eigh_full_with_pinv` see A as rank n-1 and break `X * Y ≈ I(n)`. Switch to a 2n×n tall factor (well-conditioned) and a fixed StableRNG so the test is reproducible across platforms. Co-Authored-By: Claude Opus 4.7 (1M context) --- test/test_matrixalgebra.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 4c4a21ab..d4a6c926 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -291,8 +291,12 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "gram_eigh_full" begin n = 5 - # Full-rank Hermitian PSD. - B = randn(elt, n, n) + # Full-rank Hermitian PSD. Use a tall random factor so `B' * B` + # is comfortably full rank even at Float32 precision (a square + # random `B` can produce a `B' * B` whose smallest eigenvalue + # falls below the default rtol clamp on some seeds). + rng = StableRNG(123) + B = randn(rng, elt, 2n, n) A = B' * B X = MatrixAlgebra.gram_eigh_full(A) @test X' * X ≈ A @@ -310,7 +314,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) # X * Y is the projector onto the rank-k subspace (idempotent, # rank k), and P * X ≈ X (Moore–Penrose). k = 3 - Brd = randn(elt, k, n) + Brd = randn(rng, elt, k, n) Ard = Brd' * Brd Xrd, Yrd = MatrixAlgebra.gram_eigh_full_with_pinv( Ard; rtol = sqrt(eps(real(elt))) @@ -323,7 +327,8 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "powh_safe / sqrth_safe / invsqrth_safe" begin n = 4 - B = randn(elt, n, n) + rng = StableRNG(123) + B = randn(rng, elt, 2n, n) A = B' * B sqrtA = MatrixAlgebra.sqrth_safe(A) @test sqrtA * sqrtA ≈ A