From 60edde4f3a80c4bf7e2eb5c6437327bd4e005850 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 6 May 2026 09:32:26 -0400 Subject: [PATCH 1/8] Add SketchingStrategy --- Project.toml | 1 + src/MatrixAlgebraKit.jl | 8 ++ src/algorithms.jl | 35 +++++- src/implementations/sketching.jl | 64 ++++++++++ src/interface/sketching.jl | 89 ++++++++++++++ src/interface/svd.jl | 11 +- test/sketching.jl | 17 +++ test/testsuite/TestSuite.jl | 9 +- test/testsuite/decompositions/sketching.jl | 131 +++++++++++++++++++++ 9 files changed, 360 insertions(+), 5 deletions(-) create mode 100644 src/implementations/sketching.jl create mode 100644 src/interface/sketching.jl create mode 100644 test/sketching.jl create mode 100644 test/testsuite/decompositions/sketching.jl diff --git a/Project.toml b/Project.toml index 29f0e2bc8..139d8ad18 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ authors = ["Jutho Haegeman , Lukas Devos, Katharine Hya [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 22bf79e9c..331b58bc9 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -9,6 +9,8 @@ using LinearAlgebra: Diagonal, diag, diagind, isdiag using LinearAlgebra: UpperTriangular, LowerTriangular using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt +using Random: Random + export isisometric, isunitary, ishermitian, isantihermitian export diagview, diagonal @@ -30,6 +32,8 @@ export left_polar, right_polar export left_polar!, right_polar! export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! +export left_sketch, right_sketch +export left_sketch!, right_sketch! export Householder, Native_HouseholderQR, Native_HouseholderLQ export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar @@ -50,6 +54,8 @@ export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, export notrunc, truncrank, trunctol, truncerror, truncfilter +export SketchedAlgorithm, SketchingStrategy, GaussianSketching + @static if VERSION >= v"1.11.0-DEV.469" eval( Expr( @@ -101,6 +107,7 @@ include("interface/truncation.jl") include("interface/qr.jl") include("interface/lq.jl") include("interface/svd.jl") +include("interface/sketching.jl") include("interface/eig.jl") include("interface/eigh.jl") include("interface/gen_eig.jl") @@ -113,6 +120,7 @@ include("implementations/truncation.jl") include("implementations/qr.jl") include("implementations/lq.jl") include("implementations/svd.jl") +include("implementations/sketching.jl") include("implementations/eig.jl") include("implementations/eigh.jl") include("implementations/gen_eig.jl") diff --git a/src/algorithms.jl b/src/algorithms.jl index 65a25bc18..ce413181f 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -242,12 +242,21 @@ default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK() """ abstract type TruncationStrategy end -Supertype to denote different strategies for truncated decompositions that are implemented via post-truncation. +Supertype to denote different strategies for truncated decompositions. See also [`truncate`](@ref) """ abstract type TruncationStrategy end +""" + abstract type SketchingStrategy <: AbstractAlgorithm end + +Supertype to denote different sketching strategies, used both as standalone algorithms for +[`left_sketch!`](@ref) and [`right_sketch!`](@ref) and as the `sketch` field of a +[`SketchedAlgorithm`](@ref) for self-truncating SVD. +""" +abstract type SketchingStrategy <: AbstractAlgorithm end + @doc """ MatrixAlgebraKit.select_truncation(trunc) @@ -317,7 +326,7 @@ See also [`findtruncated`](@ref) and [`findtruncated_svd`](@ref) for determining function truncate end """ - TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) + TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationStrategy) Generic wrapper type for algorithms that consist of first using `alg`, followed by a truncation through `trunc`. @@ -327,7 +336,27 @@ struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm trunc::T end +""" + SketchedAlgorithm(alg::AbstractAlgorithm, sketch::SketchingStrategy, trunc::TruncationStrategy) + +Generic wrapper type for self-truncating algorithms that produce an approximate low-rank +factorization by first applying a sketching operation specified by `sketch`, then computing +a small dense decomposition of the projected matrix using `alg`. The `driver` selects the +backend (e.g. `DefaultDriver()`, `CUSOLVER()`). +""" +struct SketchedAlgorithm{A <: AbstractAlgorithm, S <: SketchingStrategy, T <: TruncationStrategy} <: AbstractAlgorithm + alg::A + sketch::S + trunc::T +end + does_truncate(::TruncatedAlgorithm) = true +does_truncate(::SketchedAlgorithm) = true + +truncated_algorithm(alg::AbstractAlgorithm, trunc::TruncationStrategy) = + TruncatedAlgorithm(alg, trunc) +truncated_algorithm(alg::AbstractAlgorithm, sketch::SketchingStrategy) = + SketchedAlgorithm(sketch, alg, DefaultDriver()) # Utility macros # -------------- @@ -535,7 +564,7 @@ macro check_size(x, sz, size = :size) szx = $size($x) $err = $msgstart * string(szx) * " instead of expected value " * string($sz) - szx == $sz || throw(DimensionMismatch($err)) + (szx == $sz)::Bool || throw(DimensionMismatch($err)) end ) end diff --git a/src/implementations/sketching.jl b/src/implementations/sketching.jl new file mode 100644 index 000000000..14731d69e --- /dev/null +++ b/src/implementations/sketching.jl @@ -0,0 +1,64 @@ +# Inputs / defaults / outputs +# --------------------------- +copy_input(::typeof(left_sketch), A) = A +copy_input(::typeof(right_sketch), A) = A + +function initialize_output(::typeof(left_sketch!), A::AbstractMatrix, alg::GaussianSketching) + m, n = size(A) + k = min(alg.howmany, m, n) + T = float(eltype(A)) + Q = similar(A, T, (m, k)) + B = similar(A, T, (k, n)) + return Q, B +end +function initialize_output(::typeof(right_sketch!), A::AbstractMatrix, alg::GaussianSketching) + return initialize_output(left_sketch!, A, alg) +end + +function check_input(::typeof(left_sketch!), A::AbstractMatrix, (Q, B), alg::GaussianSketching) + m, n = size(A) + k = min(alg.howmany, m, n) + @assert Q isa AbstractMatrix + @check_size(Q, (m, k)) + @check_scalar(Q, A, float) + @assert B isa AbstractMatrix + @check_size(B, (k, n)) + @check_scalar(B, A, float) + return nothing +end +function check_input(::typeof(right_sketch!), A::AbstractMatrix, BPᴴ, alg::GaussianSketching) + check_input(left_sketch!, A, BPᴴ, alg) + return nothing +end + +# Gaussian sketching, native implementation +# ----------------------------------------- +function left_sketch!(A::AbstractMatrix, QB, alg::GaussianSketching) + check_input(left_sketch!, A, QB, alg) + Q, B = QB + Ω = Random.randn!(alg.rng, similar(Q, (size(A, 2), size(Q, 2)))) + Y = A * Ω + R = similar(Y, (0, 0)) + Q, _ = left_orth!(Y, (Q, R)) + for _ in 2:alg.numiter + mul!(Ω, A', Q) + mul!(Y, A, Ω) + Q, _ = left_orth!(Y, (Q, R)) + end + return Q, mul!(B, Q', A) +end + +function right_sketch!(A::AbstractMatrix, BPᴴ, alg::GaussianSketching) + check_input(right_sketch!, A, BPᴴ, alg) + B, Pᴴ = BPᴴ + Ω = Random.randn!(alg.rng, similar(Pᴴ, (size(Pᴴ, 1), size(A, 1)))) + Y = Ω * A + L = similar(Y, (0, 0)) + _, Pᴴ = right_orth!(Y, (L, Pᴴ)) + for _ in 2:alg.numiter + mul!(Ω, Pᴴ, A') + mul!(Y, Ω, A) + _, Pᴴ = right_orth!(Y, (L, Pᴴ)) + end + return mul!(B, A, Pᴴ'), Pᴴ +end diff --git a/src/interface/sketching.jl b/src/interface/sketching.jl new file mode 100644 index 000000000..8e370b443 --- /dev/null +++ b/src/interface/sketching.jl @@ -0,0 +1,89 @@ +# Gaussian sketching +# ------------------ +""" + GaussianSketching(howmany; numiter, rng) + +Sketching strategy using a Gaussian random matrix with optional power iterations to improve +accuracy on slowly-decaying spectra. + +## Fields +- `howmany::Int`: number of singular values to compute. +- `numiter::Int`: number of power iterations (`numiter ≥ 1`; the first counts as the initial + sketch). +- `rng::AbstractRNG`: random number generator used to draw the Gaussian sketch matrix. +""" +struct GaussianSketching{RNG <: Random.AbstractRNG} <: SketchingStrategy + howmany::Int + numiter::Int + rng::RNG +end + +function GaussianSketching(howmany::Integer; numiter::Integer = 2, rng::Random.AbstractRNG = Random.default_rng()) + howmany ≥ 0 || throw(ArgumentError("howmany must be non-negative")) + numiter ≥ 1 || throw(ArgumentError("numiter must be at least 1 ($numiter)")) + return GaussianSketching{typeof(rng)}(howmany, numiter, rng) +end + +# Entry points +# ------------ +""" + left_sketch(A; howmany, kwargs...) -> Q, B + left_sketch(A, alg::AbstractAlgorithm) -> Q, B + left_sketch!(A, [QB]; howmany, kwargs...) -> Q, B + left_sketch!(A, [QB], alg::AbstractAlgorithm) -> Q, B + +Compute an isometric matrix `Q` (orthonormal columns) of size m×k, whose column span approximates the range of `A` of size m×n. +Also create the core factor `B = Q' * A`. +Here `k = howmany` is the sketch dimension. + +The keyword arguments construct a [`GaussianSketching`](@ref) strategy unless an explicit `alg::SketchingStrategy` is supplied. +`howmany` is required. + +!!! note + The bang method `left_sketch!` optionally accepts the output matrices `Q, B` and possibly destroys the input matrix `A`. + Always use the return value of the function as it may not always be possible to use the provided `Q, B` as output. + +See also [`right_sketch(!)`](@ref right_sketch) and [`SketchedAlgorithm`](@ref). +""" +@functiondef left_sketch + +""" + right_sketch(A; howmany, kwargs...) -> B, Pᴴ + right_sketch(A, alg::AbstractAlgorithm) -> B, Pᴴ + right_sketch!(A, [BPᴴ]; howmany, kwargs...) -> B, Pᴴ + right_sketch!(A, [BPᴴ], alg::AbstractAlgorithm) -> B, Pᴴ + +Compute a right-isometric matrix `Pᴴ` (orthonormal rows) of size k×n, whose row span approximates the range of `A` of size m×n. +Also create the core factor `B = A * Pᴴ'` +Here `k = howmany` is the sketch dimension. + +The keyword arguments construct a [`GaussianSketching`](@ref) strategy unless an explicit `alg::SketchingStrategy` is supplied. +`howmany` is required. + +!!! note + The bang method `right_sketch!` optionally accepts the output matrices `BPᴴ` and possibly destroys the input matrix `A`. + Always use the return value of the function as it may not always be possible to use the provided `BPᴴ` as output. + +See also [`left_sketch(!)`](@ref left_sketch) and [`SketchedAlgorithm`](@ref). +""" +@functiondef right_sketch + +# Algorithm selection +# ------------------- +default_sketch_algorithm(A; kwargs...) = default_sketch_algorithm(typeof(A); kwargs...) +default_sketch_algorithm(T::Type; kwargs...) = throw(MethodError(default_sketch_algorithm, (T,))) +function default_sketch_algorithm(::Type{T}; howmany, kwargs...) where {T <: AbstractMatrix} + return GaussianSketching(howmany; kwargs...) +end +function default_sketch_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}; kwargs...) where {T, N, A} + return default_sketch_algorithm(A; kwargs...) +end +function default_sketch_algorithm(::Type{<:SubArray{T, N, A}}; kwargs...) where {T, N, A} + return default_sketch_algorithm(A; kwargs...) +end + +for f in (:left_sketch!, :right_sketch!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_sketch_algorithm(A; kwargs...) + end +end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 7f55e443f..0c1911897 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -184,9 +184,18 @@ for f in (:svd_trunc!, :svd_trunc_no_error!) isnothing(trunc) || throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) return alg + elseif alg isa SketchedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `SketchedAlgorithm`")) + return alg else alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) + trunc = select_truncation(trunc) + if trunc isa TruncationStrategy + return truncated_algorithm(alg_svd, trunc) + else + throw(ArgumentError("invalid truncation $trunc")) + end end end end diff --git a/test/sketching.jl b/test/sketching.jl new file mode 100644 index 000000000..16a46a458 --- /dev/null +++ b/test/sketching.jl @@ -0,0 +1,17 @@ +using MatrixAlgebraKit + +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) + +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +# CPU tests +# --------- +if !is_buildkite + @testset "Sketching ($T, $m, $n)" for T in BLASFloats, (m, n) in ((100, 40), (40, 100), (60, 60)) + TestSuite.seed_rng!(123) + TestSuite.test_sketching(T, (m, n)) + end +end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 36ae68304..a37c7b9c7 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -95,13 +95,19 @@ function instantiate_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz.. V, C = left_orth!(A; trunc) return mul!(A, V, C) end - function instantiate_rank_deficient_matrix(::Type{T}, sz; trunc = truncrank(div(min(sz...), 2))) where {T <: Diagonal} A = instantiate_matrix(eltype(T), sz) V, C = left_orth!(A; trunc) return Diagonal(diag(mul!(A, V, C))) end +function instantiate_almost_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz...), 2)), atol::Real = 0, rtol::Real = precision(T)) + A = instantiate_rank_deficient_matrix(T, sz; trunc) + noise = normalize(instantiate_matrix(T, sz)) + A .+= max(atol, rtol * norm(A)) * noise + return A +end + include("ad_utils.jl") include("projections.jl") @@ -116,6 +122,7 @@ include("decompositions/eig.jl") include("decompositions/eigh.jl") include("decompositions/orthnull.jl") include("decompositions/svd.jl") +include("decompositions/sketching.jl") # Mooncake # -------- diff --git a/test/testsuite/decompositions/sketching.jl b/test/testsuite/decompositions/sketching.jl new file mode 100644 index 000000000..27376a6b7 --- /dev/null +++ b/test/testsuite/decompositions/sketching.jl @@ -0,0 +1,131 @@ +using TestExtras + +function test_sketching(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "sketching $summary_str" begin + test_left_sketch(T, sz; kwargs...) + test_right_sketch(T, sz; kwargs...) + end +end + +function test_left_sketch( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "left_sketch $summary_str" begin + m, n = sz + r = min(min(m, n) ÷ 4, 5) + r > 0 || return + + A = instantiate_almost_rank_deficient_matrix(T, sz; trunc = truncrank(r), atol, rtol) + Ac = deepcopy(A) + k = min(r, m, n) + + # Does the elementary functionality work + Q, B = @testinferred left_sketch(A; howmany = r) + @test size(Q) == (m, k) + @test eltype(Q) === float(eltype(T)) + @test isisometric(Q; rtol, atol) + @test size(B) == (k, n) + @test eltype(B) === float(eltype(T)) + @test B ≈ Q' * A atol = atol rtol = rtol + @test A ≈ Q * B atol = atol rtol = rtol + @test A == Ac + + # Can I pass in outputs + Q, B = @testinferred left_sketch!(deepcopy(A), (Q, B); howmany = r) + @test size(Q) == (m, k) + @test eltype(Q) === float(eltype(T)) + @test isisometric(Q; rtol, atol) + @test size(B) == (k, n) + @test eltype(B) === float(eltype(T)) + @test B ≈ Q' * A atol = atol rtol = rtol + @test A ≈ Q * B atol = atol rtol = rtol + + # Can I pass in keywords + rng = MersenneTwister(3) + Q, B = @testinferred left_sketch(A; howmany = r, rng) + rng = MersenneTwister(3) + Q′, B′ = @testinferred left_sketch(A; howmany = r, rng) + @test Q == Q′ + @test B == B′ + + # Can I pass in algorithms + rng = MersenneTwister(3) + alg = GaussianSketching(r; rng) + Q′, B′ = @testinferred left_sketch(A, alg) + @test Q == Q′ + @test B == B′ + + # Do power iterations improve accuracy + Aflat = instantiate_matrix(T, sz) + Q1, B1 = left_sketch(Aflat, GaussianSketching(r; numiter = 1, rng = MersenneTwister(123))) + Q5, B5 = left_sketch(Aflat, GaussianSketching(r; numiter = 5, rng = MersenneTwister(123))) + e1 = norm(Aflat - Q1 * B1) / norm(Aflat) + e5 = norm(Aflat - Q5 * B5) / norm(Aflat) + @test e5 ≤ e1 + rtol + end +end + +function test_right_sketch( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "right_sketch $summary_str" begin + m, n = sz + r = min(min(m, n) ÷ 4, 5) + r > 0 || return + + A = instantiate_almost_rank_deficient_matrix(T, sz; trunc = truncrank(r), atol, rtol) + Ac = deepcopy(A) + k = min(r, m, n) + + # Does the elementary functionality work + B, Pᴴ = @testinferred right_sketch(A; howmany = r) + @test size(B) == (m, k) + @test eltype(B) === float(eltype(T)) + @test size(Pᴴ) == (k, n) + @test eltype(Pᴴ) === float(eltype(T)) + @test isisometric(Pᴴ'; rtol, atol) + @test B ≈ A * Pᴴ' atol = atol rtol = rtol + @test A ≈ B * Pᴴ atol = atol rtol = rtol + @test A == Ac + + # Can I pass in outputs + B, Pᴴ = @testinferred right_sketch!(deepcopy(A), (B, Pᴴ); howmany = r) + @test size(B) == (m, k) + @test eltype(B) === float(eltype(T)) + @test size(Pᴴ) == (k, n) + @test eltype(Pᴴ) === float(eltype(T)) + @test isisometric(Pᴴ'; rtol, atol) + @test B ≈ A * Pᴴ' atol = atol rtol = rtol + @test A ≈ B * Pᴴ atol = atol rtol = rtol + + # Can I pass in keywords + rng = MersenneTwister(3) + B, Pᴴ = @testinferred right_sketch(A; howmany = r, rng) + rng = MersenneTwister(3) + B′, Pᴴ′ = @testinferred right_sketch(A; howmany = r, rng) + @test B == B′ + @test Pᴴ == Pᴴ′ + + # Can I pass in algorithms + rng = MersenneTwister(3) + alg = GaussianSketching(r; rng) + B′, Pᴴ′ = @testinferred right_sketch(A, alg) + @test B == B′ + @test Pᴴ == Pᴴ′ + + # Do power iterations improve accuracy + Aflat = instantiate_matrix(T, sz) + B1, P1 = right_sketch(Aflat, GaussianSketching(r; numiter = 1, rng = MersenneTwister(123))) + B5, P5 = right_sketch(Aflat, GaussianSketching(r; numiter = 5, rng = MersenneTwister(123))) + e1 = norm(Aflat - B1 * P1) / norm(Aflat) + e5 = norm(Aflat - B5 * P5) / norm(Aflat) + @test e5 ≤ e1 + rtol + end +end From ec55db2317ed19890155e29a0d441dc971f0400e Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 7 May 2026 14:20:31 -0400 Subject: [PATCH 2/8] implement SketchedAlgorithm for SVD --- src/algorithms.jl | 3 + src/implementations/svd.jl | 94 +++++++++++++--------------- test/decompositions/svd.jl | 25 ++++++-- test/testsuite/decompositions/svd.jl | 28 ++++++--- 4 files changed, 85 insertions(+), 65 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index ce413181f..96b224e99 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -350,6 +350,9 @@ struct SketchedAlgorithm{A <: AbstractAlgorithm, S <: SketchingStrategy, T <: Tr trunc::T end +# utility conversion constructor +TruncatedAlgorithm(alg::SketchedAlgorithm) = TruncatedAlgorithm(alg.alg, alg.trunc) + does_truncate(::TruncatedAlgorithm) = true does_truncate(::SketchedAlgorithm) = true diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 3d20e96d4..7f999906d 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -285,66 +285,56 @@ function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm) return S end -# GPU logic (randomized SVD - CUSOLVER_Randomized has no CPU analog, kept as-is) -# --------------------------------------------------------------------------------- +# Sketched Logic +# -------------- +function initialize_output(::typeof(svd_trunc_no_error!), A::AbstractMatrix, alg::SketchedAlgorithm) + U, Vᴴ = initialize_output(left_sketch!, A, alg.sketch) + S = Diagonal(similar(U, real(eltype(U)), (size(U, 2),))) + return U, S, Vᴴ +end +initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::SketchedAlgorithm) = + initialize_output(svd_trunc_no_error!, A, alg) -function check_input( - ::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized - ) - m, n = size(A) - minmn = min(m, n) - U, S, Vᴴ = USVᴴ +function check_input(::typeof(svd_trunc_no_error!), A::AbstractMatrix, (U, S, Vᴴ), alg::SketchedAlgorithm) + check_input(left_sketch!, A, (U, Vᴴ), alg.sketch) @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix - @check_size(U, (m, m)) - @check_scalar(U, A) - @check_size(S, (minmn, minmn)) - @check_scalar(S, A, real) - @check_size(Vᴴ, (n, n)) - @check_scalar(Vᴴ, A) + k = size(U, 2) + @check_size(S, (k, k)) + @check_scalar(S, U, real) return nothing end +check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::SketchedAlgorithm) = + check_input(svd_trunc_no_error!, A, USVᴴ, alg) -function initialize_output( - ::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} - ) +function svd_trunc_no_error!(A::AbstractMatrix, (U, S, Vᴴ), alg::SketchedAlgorithm) + check_input(svd_trunc_no_error!, A, (U, S, Vᴴ), alg) m, n = size(A) - minmn = min(m, n) - U = similar(A, (m, m)) - S = Diagonal(similar(A, real(eltype(A)), (minmn,))) - Vᴴ = similar(A, (n, n)) - return (U, S, Vᴴ) -end - -function _gpu_Xgesvdr!( - A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs... - ) - throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ))) -end - -function svd_trunc_no_error!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) - U, S, Vᴴ = USVᴴ - check_input(svd_trunc_no_error!, A, (U, S, Vᴴ), alg.alg) - _gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...) - - # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong - (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - - do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool - # the output matrices here are the same size as for svd_full! - do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) - - return Utr, Str, Vᴴtr + if m ≥ n + Q, B = left_sketch!(A, (U, Vᴴ), alg.sketch) + k = size(B, 1) + U′ = similar(B, (k, k)) + Vᴴ′ = similar(B) + USVᴴ_inner = svd_compact!(B, (U′, S, Vᴴ′), alg.alg) + (Uout′, Sout, Vᴴout), _ = truncate(svd_trunc!, USVᴴ_inner, alg.trunc) + Uout = Q * Uout′ + else + B, Pᴴ = right_sketch!(A, (U, Vᴴ), alg.sketch) + k = size(B, 2) + U′ = similar(B) + Vᴴ′ = similar(B, (k, k)) + USVᴴ_inner = svd_compact!(B, (U′, S, Vᴴ′), alg.alg) + (Uout, Sout, Vᴴout′), _ = truncate(svd_trunc!, USVᴴ_inner, alg.trunc) + Vᴴout = Vᴴout′ * Pᴴ + end + get(alg.alg.kwargs, :fixgauge, true) && gaugefix!(svd_trunc!, Uout, Vᴴout) + return Uout, Sout, Vᴴout end -function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) - Utr, Str, Vᴴtr = svd_trunc_no_error!(A, USVᴴ, alg) - # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum - normS = norm(diagview(Str)) - normA = norm(A) - # equivalent to sqrt(normA^2 - normS^2) - # but may be more accurate - ϵ = sqrt((normA + normS) * abs(normA - normS)) - return Utr, Str, Vᴴtr, ϵ +function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::SketchedAlgorithm) + U, S, Vᴴ = svd_trunc_no_error!(A, USVᴴ, alg) + Na = norm(A) + Ns = norm(S) + return U, S, Vᴴ, sqrt(max(zero(Na), (Na + Ns) * (Na - Ns))) end # Deprecations diff --git a/test/decompositions/svd.jl b/test/decompositions/svd.jl index c69ed3a0e..b8401d2fa 100644 --- a/test/decompositions/svd.jl +++ b/test/decompositions/svd.jl @@ -18,17 +18,27 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" # CPU tests # --------- if !is_buildkite - # LAPACK algorithms: - for T in BLASFloats, m in (0, 54), n in (0, 37, m, 63) + @testset "LAPACK algorithms ($T, $m, $n)" for T in BLASFloats, m in (0, 54), n in (0, 37, m, 63) TestSuite.seed_rng!(123) LAPACK_SVD_ALGS = (QRIteration(), DivideAndConquer(), SafeDivideAndConquer(; fixgauge = true)) TestSuite.test_svd(T, (m, n)) TestSuite.test_svd_algs(T, (m, n), LAPACK_SVD_ALGS) @static if VERSION > v"1.11-" # Jacobi broken on 1.10 - TestSuite.test_svd_algs(T, (m, n), (LAPACK_Jacobi(),); test_full = false, test_vals = false) + TestSuite.test_svd_algs(T, (m, n), (Jacobi(),); test_full = false, test_vals = false) end end + # Sketched algorithms + for T in BLASFloats + m, n = 54, 63 + rtol = sqrt(TestSuite.precision(T)) # extra square root + algs = [ + SketchedAlgorithm(DefaultAlgorithm(), GaussianSketching(m ÷ 2, numiter = 4), truncrank(m ÷ 4)), + ] + TestSuite.test_sketched_svd(T, (m, n), algs; rtol) + TestSuite.test_sketched_svd(T, (n, m), algs) + end + # Generic floats: for T in GenericFloats, m in (0, 54), n in (0, 37, m, 63) TestSuite.seed_rng!(123) @@ -47,7 +57,7 @@ end # CUDA tests # ------------ -if CUDA.functional() +if false # CUDA.functional() # LAPACK algorithms: for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27) TestSuite.seed_rng!(123) @@ -62,7 +72,12 @@ if CUDA.functional() k = 5 p = min(m, n) - k - 2 p > 0 || continue - TestSuite.test_randomized_svd(CuMatrix{T}, (m, n), (MatrixAlgebraKit.TruncatedAlgorithm(CUSOLVER_Randomized(; k, p, niters = 20), truncrank(k)),)) + cusolver_sketch = SketchedAlgorithm( + GaussianSketching(k; oversampling = p, niters = 20), + DefaultAlgorithm(), + MatrixAlgebraKit.CUSOLVER(), + ) + TestSuite.test_randomized_svd(CuMatrix{T}, (m, n), (cusolver_sketch,)) end # Diagonal: diff --git a/test/testsuite/decompositions/svd.jl b/test/testsuite/decompositions/svd.jl index 4b89d4973..d05bf62eb 100644 --- a/test/testsuite/decompositions/svd.jl +++ b/test/testsuite/decompositions/svd.jl @@ -363,15 +363,27 @@ function test_svd_trunc_algs( end end -function test_randomized_svd(T::Type, sz, algs; kwargs...) +function test_sketched_svd( + T::Type, sz, algs; + atol::Real = 0, rtol::Real = precision(eltype(T)), kwargs... + ) summary_str = testargs_summary(T, sz) - return @testset "randomized svd_trunc! algorithm $alg $summary_str" for alg in algs - A = instantiate_matrix(T, sz) + return @testset "sketched svd_trunc! algorithm $alg $summary_str" for alg in algs + @assert alg isa SketchedAlgorithm "Invalid sketched algorithm type: $(typeof(alg))" + + A = instantiate_rank_deficient_matrix(T, sz; alg.trunc) + A += max(atol, rtol * norm(A)) * instantiate_matrix(T, sz) Ac = deepcopy(A) - m, n = size(A) - minmn = min(m, n) - S₀ = collect(svd_vals(A)) - U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc(A; alg) - @test collect(diagview(S1))[1:alg.alg.k] ≈ S₀[1:alg.alg.k] + + alg2 = MatrixAlgebraKit.TruncatedAlgorithm(alg) + + U, S, Vᴴ, ϵ = @testinferred svd_trunc(A, alg) + @test Ac == A + + U′, S′, Vᴴ′, ϵ′ = svd_trunc(A, alg2) + @test U ≈ U′ atol = atol rtol = rtol + @test S ≈ S′ atol = atol rtol = rtol + @test Vᴴ ≈ Vᴴ′ atol = atol rtol = rtol + @test ϵ ≈ ϵ′ atol = atol rtol = rtol end end From 1cd48c0051838c272200afa4ea8dc3b7041bed60 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 7 May 2026 14:38:43 -0400 Subject: [PATCH 3/8] Rework algorithm selection logic --- src/algorithms.jl | 30 ++++++++++++++++++++++++------ src/implementations/orthnull.jl | 6 +++--- src/interface/svd.jl | 19 +++++++------------ 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 96b224e99..cec142fec 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -292,7 +292,26 @@ function select_null_truncation(trunc) elseif trunc isa TruncationStrategy return trunc else - return throw(ArgumentError("Unknown truncation strategy: $trunc")) + throw(ArgumentError("Unknown truncation strategy: $trunc")) + end +end + +@doc """ + MatrixAlgebraKit.select_sketching(A, sketch) + +Construct a [`SketchingStrategy`](@ref) for `A` from the given `NamedTuple` of keywords or input strategy `sketch`. +""" select_sketching + +@inline select_sketching(A, sketch) = select_sketching(typeof(A), sketch) +@inline function select_sketching(::Type{A}, sketch) where {A} + if isnothing(sketch) + return nothing + elseif sketch isa SketchingStrategy + return sketch + elseif sketch isa NamedTuple + return select_algorithm(left_sketch!, A; sketch...) + else + throw(ArgumentError("Unknown sketching strategy: $sketch")) end end @@ -331,7 +350,7 @@ function truncate end Generic wrapper type for algorithms that consist of first using `alg`, followed by a truncation through `trunc`. """ -struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm +struct TruncatedAlgorithm{A <: AbstractAlgorithm, T <: TruncationStrategy} <: AbstractAlgorithm alg::A trunc::T end @@ -356,10 +375,9 @@ TruncatedAlgorithm(alg::SketchedAlgorithm) = TruncatedAlgorithm(alg.alg, alg.tru does_truncate(::TruncatedAlgorithm) = true does_truncate(::SketchedAlgorithm) = true -truncated_algorithm(alg::AbstractAlgorithm, trunc::TruncationStrategy) = - TruncatedAlgorithm(alg, trunc) -truncated_algorithm(alg::AbstractAlgorithm, sketch::SketchingStrategy) = - SketchedAlgorithm(sketch, alg, DefaultDriver()) +truncated_algorithm(alg::AbstractAlgorithm, trunc::TruncationStrategy, sketch = nothing) = + isnothing(sketch) ? TruncatedAlgorithm(alg, trunc) : SketchedAlgorithm(; alg, sketch, trunc) + # Utility macros # -------------- diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 1a7f88835..95e2ec366 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -116,8 +116,8 @@ function right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm}) return Nᴴ end -# randomized algorithms don't currently work for smallest values: -left_null!(A, N, alg::LeftNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) = +# randomized (sketched) algorithms don't currently work for smallest values: +left_null!(A, N, alg::LeftNullViaSVD{<:SketchedAlgorithm}) = throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet")) -right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) = +right_null!(A, Nᴴ, alg::RightNullViaSVD{<:SketchedAlgorithm}) = throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet")) diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 0c1911897..f5869b4ce 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -179,23 +179,18 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!) end for f in (:svd_trunc!, :svd_trunc_no_error!) - @eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...) - if alg isa TruncatedAlgorithm + @eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, sketch = nothing, kwargs...) + if alg isa TruncatedAlgorithm || alg isa SketchedAlgorithm isnothing(trunc) || - throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) - return alg - elseif alg isa SketchedAlgorithm - isnothing(trunc) || - throw(ArgumentError("`trunc` can't be specified when `alg` is a `SketchedAlgorithm`")) + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm` or `SketchedAlgorithm`")) + isnothing(sketch) || + throw(ArgumentError("`sketch` can't be specified when `alg` is a `TruncatedAlgorithm` or `SketchedAlgorithm`")) return alg else alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) trunc = select_truncation(trunc) - if trunc isa TruncationStrategy - return truncated_algorithm(alg_svd, trunc) - else - throw(ArgumentError("invalid truncation $trunc")) - end + sketch = select_sketching(A, sketch) + return truncated_algorithm(alg_svd, trunc, sketch) end end end From d111cd9b4225164a939e2f70222436c4c1d475ba Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 8 May 2026 13:31:34 -0400 Subject: [PATCH 4/8] refactor + deprecate CUSOLVER randomized SVD --- .../MatrixAlgebraKitCUDAExt.jl | 33 ++++++++- src/algorithms.jl | 20 +++-- src/implementations/svd.jl | 74 +++++++++++++++---- src/interface/decompositions.jl | 2 - test/decompositions/svd.jl | 2 +- 5 files changed, 104 insertions(+), 27 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 25a739df0..cb1230247 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -4,11 +4,12 @@ using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe -using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm +using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, TruncationByOrder, AbstractAlgorithm +using MatrixAlgebraKit: GaussianSketching, SketchingStrategy, SketchedAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj! import MatrixAlgebraKit: heevj!, heevd!, geev! -import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank +import MatrixAlgebraKit: _sylvester, svd_rank using CUDA, CUDA.cuBLAS using CUDA: i32 using LinearAlgebra @@ -17,6 +18,7 @@ using LinearAlgebra: BlasFloat include("yacusolver.jl") MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER() +MatrixAlgebraKit.default_driver(::Type{<:SketchedAlgorithm}, ::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER() function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} return QRIteration(; kwargs...) @@ -50,8 +52,31 @@ end gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...) -_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = - YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...) +# Sketched SVD via cuSOLVER's gesvdr kernel +function gesvdr!( + ::CUSOLVER, A::StridedCuMatrix, S, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; + sketch::GaussianSketching, trunc::TruncationByOrder, alg::AbstractAlgorithm = DefaultAlgorithm() + ) + isempty(A) && return U, S, Vᴴ + m, n = size(A); minmn = min(m, n) + k = trunc.howmany + 1 ≤ k ≤ minmn || + throw(ArgumentError("trunc.howmany=$k must satisfy 1 ≤ k ≤ min(size(A))=$minmn")) + p = sketch.howmany - k + p ≥ 0 || throw( + ArgumentError( + "sketch.howmany=$(sketch.howmany) must be ≥ trunc.howmany=$k" + ) + ) + p = min(p, minmn - k - 1) + niters = sketch.numiter - 1 + + Uk = view(U, :, 1:k) + Vᴴk = view(Vᴴ, 1:k, :) + Sk = view(diagview(S), 1:k) + YACUSOLVER.gesvdr!(A, Sk, Uk, Vᴴk; k, p, niters) + return Uk, S, Vᴴk +end geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, Dd, V) diff --git a/src/algorithms.jl b/src/algorithms.jl index cec142fec..dba747127 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -356,17 +356,25 @@ struct TruncatedAlgorithm{A <: AbstractAlgorithm, T <: TruncationStrategy} <: Ab end """ - SketchedAlgorithm(alg::AbstractAlgorithm, sketch::SketchingStrategy, trunc::TruncationStrategy) + SketchedAlgorithm(; + alg::AbstractAlgorithm, sketch::SketchingStrategy, + trunc::TruncationStrategy, driver::Driver = DefaultDriver() + ) Generic wrapper type for self-truncating algorithms that produce an approximate low-rank factorization by first applying a sketching operation specified by `sketch`, then computing a small dense decomposition of the projected matrix using `alg`. The `driver` selects the -backend (e.g. `DefaultDriver()`, `CUSOLVER()`). -""" -struct SketchedAlgorithm{A <: AbstractAlgorithm, S <: SketchingStrategy, T <: TruncationStrategy} <: AbstractAlgorithm - alg::A +backend implementing the sketched factorization (e.g. `Native()` for the generic +sketch-then-decompose pipeline, `CUSOLVER()` for the fused `gesvdr` kernel). +""" +@kwdef struct SketchedAlgorithm{ + A <: AbstractAlgorithm, S <: SketchingStrategy, + T <: TruncationStrategy, D <: Driver, + } <: AbstractAlgorithm + alg::A = DefaultAlgorithm() sketch::S - trunc::T + trunc::T = notrunc() + driver::D = DefaultDriver() end # utility conversion constructor diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 7f999906d..f67c1cd21 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -308,35 +308,47 @@ check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::SketchedAlgori function svd_trunc_no_error!(A::AbstractMatrix, (U, S, Vᴴ), alg::SketchedAlgorithm) check_input(svd_trunc_no_error!, A, (U, S, Vᴴ), alg) + return gesvdr!(alg.driver, A, S, U, Vᴴ; alg.sketch, alg.alg, alg.trunc) +end + +function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::SketchedAlgorithm) + U, S, Vᴴ = svd_trunc_no_error!(A, USVᴴ, alg) + Na = norm(A) + Ns = norm(S) + return U, S, Vᴴ, sqrt(max(zero(Na), (Na + Ns) * (Na - Ns))) +end + +# gesvdr! drivers +# --------------- +default_driver(::Type{<:SketchedAlgorithm}, ::Type{<:AbstractArray}) = Native() + +gesvdr!(::DefaultDriver, A, S, U, Vᴴ; kwargs...) = + gesvdr!(default_driver(SketchedAlgorithm, A), A, S, U, Vᴴ; kwargs...) + +function gesvdr!( + ::Native, A::AbstractMatrix, S, U, Vᴴ; + sketch::SketchingStrategy, alg::AbstractAlgorithm, + trunc::TruncationStrategy + ) m, n = size(A) if m ≥ n - Q, B = left_sketch!(A, (U, Vᴴ), alg.sketch) + Q, B = left_sketch!(A, (U, Vᴴ), sketch) k = size(B, 1) U′ = similar(B, (k, k)) Vᴴ′ = similar(B) - USVᴴ_inner = svd_compact!(B, (U′, S, Vᴴ′), alg.alg) - (Uout′, Sout, Vᴴout), _ = truncate(svd_trunc!, USVᴴ_inner, alg.trunc) + Uout′, Sout, Vᴴout, _ = svd_trunc!(B, (U′, S, Vᴴ′), TruncatedAlgorithm(alg, trunc)) Uout = Q * Uout′ else - B, Pᴴ = right_sketch!(A, (U, Vᴴ), alg.sketch) + B, Pᴴ = right_sketch!(A, (U, Vᴴ), sketch) k = size(B, 2) U′ = similar(B) Vᴴ′ = similar(B, (k, k)) - USVᴴ_inner = svd_compact!(B, (U′, S, Vᴴ′), alg.alg) - (Uout, Sout, Vᴴout′), _ = truncate(svd_trunc!, USVᴴ_inner, alg.trunc) + Uout, Sout, Vᴴout′, _ = svd_trunc!(B, (U′, S, Vᴴ′), TruncatedAlgorithm(alg, trunc)) Vᴴout = Vᴴout′ * Pᴴ end - get(alg.alg.kwargs, :fixgauge, true) && gaugefix!(svd_trunc!, Uout, Vᴴout) return Uout, Sout, Vᴴout end -function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::SketchedAlgorithm) - U, S, Vᴴ = svd_trunc_no_error!(A, USVᴴ, alg) - Na = norm(A) - Ns = norm(S) - return U, S, Vᴴ, sqrt(max(zero(Na), (Na + Ns) * (Na - Ns))) -end - # Deprecations # ------------ for algtype in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Jacobi, :Bisection) @@ -380,6 +392,40 @@ for (algtype, newtype, drivertype) in ( end end +# CUSOLVER_Randomized → SketchedAlgorithm with driver = CUSOLVER() +function _cusolver_randomized_to_sketched(alg::CUSOLVER_Randomized) + k = alg.kwargs.k + p = alg.kwargs.p + niters = alg.kwargs.niters + return SketchedAlgorithm( + QRIteration(), + GaussianSketching(k + p; numiter = niters + 1), + truncrank(k); + driver = CUSOLVER(), + ) +end + +for f! in (:svd_trunc!, :svd_trunc_no_error!) + @eval Base.@deprecate( + $f!(A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized), + $f!(A, USVᴴ, _cusolver_randomized_to_sketched(alg)) + ) +end + +@inline function select_algorithm(::typeof(svd_trunc!), A, alg::CUSOLVER_Randomized; kwargs...) + Base.depwarn( + "`CUSOLVER_Randomized` is deprecated; use \ + `SketchedAlgorithm(QRIteration(), GaussianSketching(k+p; numiter=niters+1), truncrank(k); driver=CUSOLVER())` instead.", + :select_algorithm, + ) + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) + return _cusolver_randomized_to_sketched(alg) +end +@inline function select_algorithm(::typeof(svd_trunc_no_error!), A, alg::CUSOLVER_Randomized; kwargs...) + return select_algorithm(svd_trunc!, A, alg; kwargs...) +end + # GLA_QRIteration SVD deprecations (eigh methods remain in the GLA extension) Base.@deprecate( svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration), diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 4077214e8..6f7e55c68 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -408,8 +408,6 @@ for more information. """ @algdef CUSOLVER_Randomized -does_truncate(::TruncatedAlgorithm{<:CUSOLVER_Randomized}) = true - """ CUSOLVER_Simple(; fixgauge = default_fixgauge()) diff --git a/test/decompositions/svd.jl b/test/decompositions/svd.jl index b8401d2fa..0fd3e4972 100644 --- a/test/decompositions/svd.jl +++ b/test/decompositions/svd.jl @@ -33,7 +33,7 @@ if !is_buildkite m, n = 54, 63 rtol = sqrt(TestSuite.precision(T)) # extra square root algs = [ - SketchedAlgorithm(DefaultAlgorithm(), GaussianSketching(m ÷ 2, numiter = 4), truncrank(m ÷ 4)), + SketchedAlgorithm(; sketch = GaussianSketching(m ÷ 2, numiter = 4), trunc = truncrank(m ÷ 4)), ] TestSuite.test_sketched_svd(T, (m, n), algs; rtol) TestSuite.test_sketched_svd(T, (n, m), algs) From e67940718222fcec0e0e1fa4dcc6573d641c62aa Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 8 May 2026 14:00:48 -0400 Subject: [PATCH 5/8] some more test updates --- test/decompositions/svd.jl | 2 +- test/testsuite/TestSuite.jl | 5 ++++- test/testsuite/decompositions/svd.jl | 12 ++++++++---- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/test/decompositions/svd.jl b/test/decompositions/svd.jl index 0fd3e4972..a62b68a61 100644 --- a/test/decompositions/svd.jl +++ b/test/decompositions/svd.jl @@ -43,7 +43,7 @@ if !is_buildkite for T in GenericFloats, m in (0, 54), n in (0, 37, m, 63) TestSuite.seed_rng!(123) TestSuite.test_svd(T, (m, n)) - TestSuite.test_svd_algs(T, (m, n), (GLA_QRIteration(),)) + TestSuite.test_svd_algs(T, (m, n), (QRIteration(; driver = MatrixAlgebraKit.GLA()),)) end # Diagonal: diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index a37c7b9c7..87cbaa554 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -101,7 +101,10 @@ function instantiate_rank_deficient_matrix(::Type{T}, sz; trunc = truncrank(div( return Diagonal(diag(mul!(A, V, C))) end -function instantiate_almost_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz...), 2)), atol::Real = 0, rtol::Real = precision(T)) +function instantiate_almost_rank_deficient_matrix( + T, sz; + trunc = truncrank(div(min(sz...), 2)), atol::Real = 0, rtol::Real = precision(T) + ) A = instantiate_rank_deficient_matrix(T, sz; trunc) noise = normalize(instantiate_matrix(T, sz)) A .+= max(atol, rtol * norm(A)) * noise diff --git a/test/testsuite/decompositions/svd.jl b/test/testsuite/decompositions/svd.jl index d05bf62eb..b5b5a700b 100644 --- a/test/testsuite/decompositions/svd.jl +++ b/test/testsuite/decompositions/svd.jl @@ -371,19 +371,23 @@ function test_sketched_svd( return @testset "sketched svd_trunc! algorithm $alg $summary_str" for alg in algs @assert alg isa SketchedAlgorithm "Invalid sketched algorithm type: $(typeof(alg))" - A = instantiate_rank_deficient_matrix(T, sz; alg.trunc) - A += max(atol, rtol * norm(A)) * instantiate_matrix(T, sz) + A = instantiate_almost_rank_deficient_matrix(T, sz; alg.trunc, atol, rtol) Ac = deepcopy(A) alg2 = MatrixAlgebraKit.TruncatedAlgorithm(alg) U, S, Vᴴ, ϵ = @testinferred svd_trunc(A, alg) @test Ac == A + ϵ′ = norm(A - U * S * Vᴴ) + @test ϵ′ ≈ ϵ atol = sqrt(rtol) * max(one(ϵ′), ϵ′) # comparison to 0 is hard, very imprecise calculation - U′, S′, Vᴴ′, ϵ′ = svd_trunc(A, alg2) + U′, S′, Vᴴ′ = svd_trunc_no_error(A, alg2) + + # Need gauge fixing for comparison + U, Vᴴ = MatrixAlgebraKit.gaugefix!(svd_trunc!, U, Vᴴ) + U′, Vᴴ′ = MatrixAlgebraKit.gaugefix!(svd_trunc!, U′, Vᴴ′) @test U ≈ U′ atol = atol rtol = rtol @test S ≈ S′ atol = atol rtol = rtol @test Vᴴ ≈ Vᴴ′ atol = atol rtol = rtol - @test ϵ ≈ ϵ′ atol = atol rtol = rtol end end From 213fe3c7f7a7cc24057821d62cf22e5dd9e72eff Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 8 May 2026 14:35:57 -0400 Subject: [PATCH 6/8] revert some unintended changes --- src/algorithms.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index dba747127..d8eb61d5e 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -242,7 +242,7 @@ default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK() """ abstract type TruncationStrategy end -Supertype to denote different strategies for truncated decompositions. +Supertype to denote different strategies for truncated decompositions that are implemented via post-truncation. See also [`truncate`](@ref) """ @@ -593,7 +593,7 @@ macro check_size(x, sz, size = :size) szx = $size($x) $err = $msgstart * string(szx) * " instead of expected value " * string($sz) - (szx == $sz)::Bool || throw(DimensionMismatch($err)) + (szx == $sz) || throw(DimensionMismatch($err)) end ) end From ffbfa8553251110346f6400553110a1db266df8c Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sat, 9 May 2026 17:17:34 -0400 Subject: [PATCH 7/8] Some GPU fixes --- .../MatrixAlgebraKitCUDAExt.jl | 35 +++++------ ext/MatrixAlgebraKitCUDAExt/yacusolver.jl | 58 ++++++++++--------- src/implementations/svd.jl | 34 ++++++++++- src/yalapack.jl | 2 +- test/decompositions/svd.jl | 17 +++--- 5 files changed, 89 insertions(+), 57 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index cb1230247..1085d7636 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -52,30 +52,27 @@ end gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...) -# Sketched SVD via cuSOLVER's gesvdr kernel +# Sketched SVD via cuSOLVER's gesvdr kernel. +# The full m×m / n×n shapes of U / Vᴴ allow YACUSOLVER.gesvdr! to reuse them as cuSOLVER workspace. +# `alg` is accepted but unused: cuSOLVER's gesvdr fuses the inner SVD itself. function gesvdr!( ::CUSOLVER, A::StridedCuMatrix, S, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; sketch::GaussianSketching, trunc::TruncationByOrder, alg::AbstractAlgorithm = DefaultAlgorithm() ) isempty(A) && return U, S, Vᴴ - m, n = size(A); minmn = min(m, n) - k = trunc.howmany - 1 ≤ k ≤ minmn || - throw(ArgumentError("trunc.howmany=$k must satisfy 1 ≤ k ≤ min(size(A))=$minmn")) - p = sketch.howmany - k - p ≥ 0 || throw( - ArgumentError( - "sketch.howmany=$(sketch.howmany) must be ≥ trunc.howmany=$k" - ) - ) - p = min(p, minmn - k - 1) - niters = sketch.numiter - 1 - - Uk = view(U, :, 1:k) - Vᴴk = view(Vᴴ, 1:k, :) - Sk = view(diagview(S), 1:k) - YACUSOLVER.gesvdr!(A, Sk, Uk, Vᴴk; k, p, niters) - return Uk, S, Vᴴk + m, n = size(A) + sketch_amount = min(sketch.howmany, m, n) + k = min(trunc.howmany, m, n) + p = max(sketch_amount - k, 0) + numiter = sketch.numiter + + V = Vᴴ # gesvdr returns V, but this has to be the same size so we will use this as workspace + + YACUSOLVER.gesvdr!(A, diagview(S), U, V; k, p, numiter) + + # Truncate requires Vᴴ, so we adjoint here + USVᴴtrunc, _ = MatrixAlgebraKit.truncate(MatrixAlgebraKit.svd_trunc!, (U, S, V'), trunc) + return USVᴴtrunc end geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) = diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index 0cfa64c3c..6a3f4358f 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -266,32 +266,44 @@ for (bname, fname, elty, relty) in end end -# Wrapper for randomized SVD +# Wrapper for randomized SVD. +# Caller must supply full-size buffers: U is (m, m) and Vᴴ is (n, n); both are reused +# directly as cuSOLVER's workspace, and Vᴴ is converted in place from V to Vᴴ on the +# leading k rows after cuSOLVER returns. +# !!! Warning: this function takes in/returns V instead of Vᴴ function gesvdr!( A::StridedCuMatrix{T}, S::StridedCuVector = similar(A, real(T), min(size(A)...)), - U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)), - Vᴴ::StridedCuMatrix{T} = similar(A, T, min(size(A)...), size(A, 2)); + U::StridedCuMatrix{T} = similar(A, T, size(A, 1), size(A, 1)), + V::StridedCuMatrix{T} = similar(A, T, size(A, 2), size(A, 2)); k::Int = length(S), p::Int = min(size(A)...) - k - 1, - niters::Int = 1 + numiter::Int = 1, ) where {T <: BlasFloat} - chkstride1(A, U, S, Vᴴ) + chkstride1(A, U, S, V) m, n = size(A) minmn = min(m, n) - jobu = length(U) == 0 ? 'N' : 'S' - jobv = length(Vᴴ) == 0 ? 'N' : 'S' R = eltype(S) - k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)")) - k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)")) R == real(T) || throw(ArgumentError("S does not have the matching real `eltype` of A")) - - Ṽ = similar(Vᴴ, (n, n)) - Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m)) + length(S) == minmn || + throw(DimensionMismatch("length of S ($(length(S))) must equal min(size(A)) = $minmn")) + size(U) == (m, m) || + throw(DimensionMismatch("U must have shape (m, m) = ($m, $m); got $(size(U))")) + size(V) == (n, n) || + throw(DimensionMismatch("V must have shape (n, n) = ($n, $n); got $(size(V))")) + k < minmn || + throw(DimensionMismatch("rank k ($k) must be less than min(size(A)) = $minmn")) + k + p < minmn || + throw(DimensionMismatch("k + p ($(k + p)) must be less than min(size(A)) = $minmn")) + + isempty(A) && return S, U, V + + jobu = 'S' + jobv = 'S' lda = max(1, stride(A, 2)) - ldu = max(1, stride(Ũ, 2)) - ldv = max(1, stride(Ṽ, 2)) + ldu = max(1, stride(U, 2)) + ldv = max(1, stride(V, 2)) params = cuSOLVER.CuSolverParameters() dh = cuSOLVER.dense_handle() @@ -299,8 +311,8 @@ function gesvdr!( out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) cuSOLVER.cusolverDnXgesvdr_bufferSize( - dh, params, jobu, jobv, m, n, k, p, niters, - T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv, + dh, params, jobu, jobv, m, n, k, p, numiter, + T, A, lda, R, S, T, U, ldu, T, V, ldv, T, out_gpu, out_cpu ) @@ -311,8 +323,8 @@ function gesvdr!( bufferSize()... ) do buffer_gpu, buffer_cpu return cuSOLVER.cusolverDnXgesvdr( - dh, params, jobu, jobv, m, n, k, p, niters, - T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv, + dh, params, jobu, jobv, m, n, k, p, numiter, + T, A, lda, R, S, T, U, ldu, T, V, ldv, T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), dh.info @@ -321,16 +333,8 @@ function gesvdr!( flag = @allowscalar dh.info[1] cuSOLVER.chklapackerror(BlasInt(flag)) - if Ũ !== U && length(U) > 0 - U .= view(Ũ, 1:m, 1:size(U, 2)) - end - if length(Vᴴ) > 0 - Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n) - end - Ũ !== U && CUDA.unsafe_free!(Ũ) - CUDA.unsafe_free!(Ṽ) - return S, U, Vᴴ + return S, U, V end # Wrapper for general eigensolver diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index f67c1cd21..5c9fe8038 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -311,6 +311,36 @@ function svd_trunc_no_error!(A::AbstractMatrix, (U, S, Vᴴ), alg::SketchedAlgor return gesvdr!(alg.driver, A, S, U, Vᴴ; alg.sketch, alg.alg, alg.trunc) end +# CUSOLVER's gesvdr kernel requires full U and Vᴴ +function initialize_output( + ::typeof(svd_trunc_no_error!), A::AbstractMatrix, + alg::SketchedAlgorithm{<:AbstractAlgorithm, <:SketchingStrategy, <:TruncationStrategy, CUSOLVER}, + ) + m, n = size(A) + minmn = min(m, n) + T = float(eltype(A)) + U = similar(A, T, (m, m)) + S = Diagonal(similar(A, real(T), (minmn,))) + Vᴴ = similar(A, T, (n, n)) + return (U, S, Vᴴ) +end + +function check_input( + ::typeof(svd_trunc_no_error!), A::AbstractMatrix, (U, S, Vᴴ), + alg::SketchedAlgorithm{<:AbstractAlgorithm, <:SketchingStrategy, <:TruncationStrategy, CUSOLVER}, + ) + m, n = size(A) + minmn = min(m, n) + @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix + @check_size(U, (m, m)) + @check_scalar(U, A) + @check_size(S, (minmn, minmn)) + @check_scalar(S, A, real) + @check_size(Vᴴ, (n, n)) + @check_scalar(Vᴴ, A) + return nothing +end + function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::SketchedAlgorithm) U, S, Vᴴ = svd_trunc_no_error!(A, USVᴴ, alg) Na = norm(A) @@ -399,7 +429,7 @@ function _cusolver_randomized_to_sketched(alg::CUSOLVER_Randomized) niters = alg.kwargs.niters return SketchedAlgorithm( QRIteration(), - GaussianSketching(k + p; numiter = niters + 1), + GaussianSketching(k + p; numiter = niters), truncrank(k); driver = CUSOLVER(), ) @@ -415,7 +445,7 @@ end @inline function select_algorithm(::typeof(svd_trunc!), A, alg::CUSOLVER_Randomized; kwargs...) Base.depwarn( "`CUSOLVER_Randomized` is deprecated; use \ - `SketchedAlgorithm(QRIteration(), GaussianSketching(k+p; numiter=niters+1), truncrank(k); driver=CUSOLVER())` instead.", + `SketchedAlgorithm(QRIteration(), GaussianSketching(k+p; numiter=niters), truncrank(k); driver=CUSOLVER())` instead.", :select_algorithm, ) isempty(kwargs) || diff --git a/src/yalapack.jl b/src/yalapack.jl index 57e974bda..7da139b0b 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -2351,7 +2351,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in throw(DimensionMismatch("length mismatch between A ($n) and S ($(length(S)))")) lda = max(1, stride(A, 2)) - mv = Ref{BlasInt}() # unused + mv = Ref{BlasInt}(0) # unused by LAPACK when JOBV='V', but must satisfy MV ≥ 0 input check if jobv == 'V' if U !== A V = view(U, 1:n, 1:n) # use U as V storage diff --git a/test/decompositions/svd.jl b/test/decompositions/svd.jl index a62b68a61..f2f5acb87 100644 --- a/test/decompositions/svd.jl +++ b/test/decompositions/svd.jl @@ -36,7 +36,7 @@ if !is_buildkite SketchedAlgorithm(; sketch = GaussianSketching(m ÷ 2, numiter = 4), trunc = truncrank(m ÷ 4)), ] TestSuite.test_sketched_svd(T, (m, n), algs; rtol) - TestSuite.test_sketched_svd(T, (n, m), algs) + TestSuite.test_sketched_svd(T, (n, m), algs; rtol) end # Generic floats: @@ -57,7 +57,7 @@ end # CUDA tests # ------------ -if false # CUDA.functional() +if CUDA.functional() # LAPACK algorithms: for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27) TestSuite.seed_rng!(123) @@ -66,18 +66,19 @@ if false # CUDA.functional() TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS) end - # Randomized SVD: + # Sketched SVD: for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27) TestSuite.seed_rng!(123) k = 5 p = min(m, n) - k - 2 p > 0 || continue - cusolver_sketch = SketchedAlgorithm( - GaussianSketching(k; oversampling = p, niters = 20), - DefaultAlgorithm(), - MatrixAlgebraKit.CUSOLVER(), + rtol = sqrt(TestSuite.precision(T)) # extra square root + cusolver_sketch = SketchedAlgorithm(; + sketch = GaussianSketching(k + p; numiter = 20), + trunc = truncrank(k), + driver = MatrixAlgebraKit.CUSOLVER(), ) - TestSuite.test_randomized_svd(CuMatrix{T}, (m, n), (cusolver_sketch,)) + TestSuite.test_sketched_svd(CuMatrix{T}, (m, n), (cusolver_sketch,); rtol) end # Diagonal: From ff439bdf8bea126128c57ac4937f87e2596b5964 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 10 May 2026 08:29:58 -0400 Subject: [PATCH 8/8] Fix method ambiguity --- src/algorithms.jl | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index d8eb61d5e..15cda5020 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -220,23 +220,37 @@ Driver to select GenericSchur.jl as the implementation strategy. struct GS <: Driver end # In order to avoid amibiguities, this method is implemented in a tiered way -# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A)) -# default_driver(Talg, TA) -> default_driver(TA) +# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A)) +# default_driver(Talg, A) -> default_driver(Talg, typeof(A)) +# default_driver(Talg, TA) -> default_driver(Talg, _unwrapped_array_type(TA)) | default_driver(TA) +# default_driver(TA) -> driver # This is to try and minimize ambiguity while allowing overloading at multiple levels @inline default_driver(alg::AbstractAlgorithm, A) = default_driver(typeof(alg), A isa Type ? A : typeof(A)) @inline default_driver(::Type{Alg}, A) where {Alg <: AbstractAlgorithm} = default_driver(Alg, typeof(A)) -@inline default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA} = default_driver(TA) + +# Generic 2-arg fallback: if `TA` is a supported wrapper type, recurse on the +# unwrapped storage type (so algorithm-specialized methods on the parent array +# type still apply). Otherwise drop the algorithm and fall through to the +# array-only dispatch. +@inline function default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA <: AbstractArray} + UA = _unwrapped_array_type(TA) + return UA === TA ? default_driver(TA) : default_driver(Alg, UA) +end # defaults default_driver(::Type{TA}) where {TA <: AbstractArray} = Native() # default fallback default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK() -# wrapper types -@inline default_driver(::Type{Alg}, ::Type{<:SubArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A) -@inline default_driver(::Type{Alg}, ::Type{<:Base.ReshapedArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A) +# wrapper types (1-arg form, reached via the generic 2-arg fallback) @inline default_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_driver(A) @inline default_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_driver(A) +# Internal helper: strip supported wrapper types to the underlying storage +# array type. Add a new method here when introducing additional wrappers. +@inline _unwrapped_array_type(::Type{TA}) where {TA <: AbstractArray} = TA +@inline _unwrapped_array_type(::Type{<:SubArray{T, N, A}}) where {T, N, A} = _unwrapped_array_type(A) +@inline _unwrapped_array_type(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = _unwrapped_array_type(A) + # Truncation strategy # ------------------- """