Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ authors = ["Jutho Haegeman <jutho.haegeman@ugent.be>, Lukas Devos, Katharine Hya
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand Down
30 changes: 26 additions & 4 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, TruncationByOrder, AbstractAlgorithm
using MatrixAlgebraKit: GaussianSketching, SketchingStrategy, SketchedAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
import MatrixAlgebraKit: heevj!, heevd!, geev!
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank
import MatrixAlgebraKit: _sylvester, svd_rank
using CUDA, CUDA.cuBLAS
using CUDA: i32
using LinearAlgebra
Expand All @@ -17,6 +18,7 @@ using LinearAlgebra: BlasFloat
include("yacusolver.jl")

MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
MatrixAlgebraKit.default_driver(::Type{<:SketchedAlgorithm}, ::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return QRIteration(; kwargs...)
Expand Down Expand Up @@ -50,8 +52,28 @@ end
gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...)

_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)
# Sketched SVD via cuSOLVER's gesvdr kernel.
# The full m×m / n×n shapes of U / Vᴴ allow YACUSOLVER.gesvdr! to reuse them as cuSOLVER workspace.
# `alg` is accepted but unused: cuSOLVER's gesvdr fuses the inner SVD itself.
function gesvdr!(
::CUSOLVER, A::StridedCuMatrix, S, U::StridedCuMatrix, Vᴴ::StridedCuMatrix;
sketch::GaussianSketching, trunc::TruncationByOrder, alg::AbstractAlgorithm = DefaultAlgorithm()
)
isempty(A) && return U, S, Vᴴ
m, n = size(A)
sketch_amount = min(sketch.howmany, m, n)
k = min(trunc.howmany, m, n)
p = max(sketch_amount - k, 0)
numiter = sketch.numiter

V = Vᴴ # gesvdr returns V, but this has to be the same size so we will use this as workspace

YACUSOLVER.gesvdr!(A, diagview(S), U, V; k, p, numiter)

# Truncate requires Vᴴ, so we adjoint here
USVᴴtrunc, _ = MatrixAlgebraKit.truncate(MatrixAlgebraKit.svd_trunc!, (U, S, V'), trunc)
return USVᴴtrunc
end

geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, Dd, V)
Expand Down
58 changes: 31 additions & 27 deletions ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,41 +266,53 @@ for (bname, fname, elty, relty) in
end
end

# Wrapper for randomized SVD
# Wrapper for randomized SVD.
# Caller must supply full-size buffers: U is (m, m) and Vᴴ is (n, n); both are reused
# directly as cuSOLVER's workspace, and Vᴴ is converted in place from V to Vᴴ on the
# leading k rows after cuSOLVER returns.
# !!! Warning: this function takes in/returns V instead of Vᴴ
function gesvdr!(
A::StridedCuMatrix{T},
S::StridedCuVector = similar(A, real(T), min(size(A)...)),
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
Vᴴ::StridedCuMatrix{T} = similar(A, T, min(size(A)...), size(A, 2));
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), size(A, 1)),
V::StridedCuMatrix{T} = similar(A, T, size(A, 2), size(A, 2));
k::Int = length(S),
p::Int = min(size(A)...) - k - 1,
niters::Int = 1
numiter::Int = 1,
) where {T <: BlasFloat}
chkstride1(A, U, S, Vᴴ)
chkstride1(A, U, S, V)
m, n = size(A)
minmn = min(m, n)
jobu = length(U) == 0 ? 'N' : 'S'
jobv = length(Vᴴ) == 0 ? 'N' : 'S'
R = eltype(S)
k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)"))
k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)"))
R == real(T) ||
throw(ArgumentError("S does not have the matching real `eltype` of A"))

Ṽ = similar(Vᴴ, (n, n))
Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m))
length(S) == minmn ||
throw(DimensionMismatch("length of S ($(length(S))) must equal min(size(A)) = $minmn"))
size(U) == (m, m) ||
throw(DimensionMismatch("U must have shape (m, m) = ($m, $m); got $(size(U))"))
size(V) == (n, n) ||
throw(DimensionMismatch("V must have shape (n, n) = ($n, $n); got $(size(V))"))
k < minmn ||
throw(DimensionMismatch("rank k ($k) must be less than min(size(A)) = $minmn"))
k + p < minmn ||
throw(DimensionMismatch("k + p ($(k + p)) must be less than min(size(A)) = $minmn"))

isempty(A) && return S, U, V

jobu = 'S'
jobv = 'S'
lda = max(1, stride(A, 2))
ldu = max(1, stride(, 2))
ldv = max(1, stride(, 2))
ldu = max(1, stride(U, 2))
ldv = max(1, stride(V, 2))
params = cuSOLVER.CuSolverParameters()
dh = cuSOLVER.dense_handle()

function bufferSize()
out_cpu = Ref{Csize_t}(0)
out_gpu = Ref{Csize_t}(0)
cuSOLVER.cusolverDnXgesvdr_bufferSize(
dh, params, jobu, jobv, m, n, k, p, niters,
T, A, lda, R, S, T, , ldu, T, , ldv,
dh, params, jobu, jobv, m, n, k, p, numiter,
T, A, lda, R, S, T, U, ldu, T, V, ldv,
T, out_gpu, out_cpu
)

Expand All @@ -311,8 +323,8 @@ function gesvdr!(
bufferSize()...
) do buffer_gpu, buffer_cpu
return cuSOLVER.cusolverDnXgesvdr(
dh, params, jobu, jobv, m, n, k, p, niters,
T, A, lda, R, S, T, , ldu, T, , ldv,
dh, params, jobu, jobv, m, n, k, p, numiter,
T, A, lda, R, S, T, U, ldu, T, V, ldv,
T, buffer_gpu, sizeof(buffer_gpu),
buffer_cpu, sizeof(buffer_cpu),
dh.info
Expand All @@ -321,16 +333,8 @@ function gesvdr!(

flag = @allowscalar dh.info[1]
cuSOLVER.chklapackerror(BlasInt(flag))
if Ũ !== U && length(U) > 0
U .= view(Ũ, 1:m, 1:size(U, 2))
end
if length(Vᴴ) > 0
Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n)
end
Ũ !== U && CUDA.unsafe_free!(Ũ)
CUDA.unsafe_free!(Ṽ)

return S, U, Vᴴ
return S, U, V
end

# Wrapper for general eigensolver
Expand Down
8 changes: 8 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using LinearAlgebra: Diagonal, diag, diagind, isdiag
using LinearAlgebra: UpperTriangular, LowerTriangular
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt

using Random: Random

export isisometric, isunitary, ishermitian, isantihermitian
export diagview, diagonal

Expand All @@ -30,6 +32,8 @@ export left_polar, right_polar
export left_polar!, right_polar!
export left_orth, right_orth, left_null, right_null
export left_orth!, right_orth!, left_null!, right_null!
export left_sketch, right_sketch
export left_sketch!, right_sketch!

export Householder, Native_HouseholderQR, Native_HouseholderLQ
export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar
Expand All @@ -50,6 +54,8 @@ export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi,

export notrunc, truncrank, trunctol, truncerror, truncfilter

export SketchedAlgorithm, SketchingStrategy, GaussianSketching

@static if VERSION >= v"1.11.0-DEV.469"
eval(
Expr(
Expand Down Expand Up @@ -101,6 +107,7 @@ include("interface/truncation.jl")
include("interface/qr.jl")
include("interface/lq.jl")
include("interface/svd.jl")
include("interface/sketching.jl")
include("interface/eig.jl")
include("interface/eigh.jl")
include("interface/gen_eig.jl")
Expand All @@ -113,6 +120,7 @@ include("implementations/truncation.jl")
include("implementations/qr.jl")
include("implementations/lq.jl")
include("implementations/svd.jl")
include("implementations/sketching.jl")
include("implementations/eig.jl")
include("implementations/eigh.jl")
include("implementations/gen_eig.jl")
Expand Down
92 changes: 82 additions & 10 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,23 +220,37 @@ Driver to select GenericSchur.jl as the implementation strategy.
struct GS <: Driver end

# In order to avoid amibiguities, this method is implemented in a tiered way
# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A))
# default_driver(Talg, TA) -> default_driver(TA)
# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A))
# default_driver(Talg, A) -> default_driver(Talg, typeof(A))
# default_driver(Talg, TA) -> default_driver(Talg, _unwrapped_array_type(TA)) | default_driver(TA)
# default_driver(TA) -> driver
# This is to try and minimize ambiguity while allowing overloading at multiple levels
@inline default_driver(alg::AbstractAlgorithm, A) = default_driver(typeof(alg), A isa Type ? A : typeof(A))
@inline default_driver(::Type{Alg}, A) where {Alg <: AbstractAlgorithm} = default_driver(Alg, typeof(A))
@inline default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA} = default_driver(TA)

# Generic 2-arg fallback: if `TA` is a supported wrapper type, recurse on the
# unwrapped storage type (so algorithm-specialized methods on the parent array
# type still apply). Otherwise drop the algorithm and fall through to the
# array-only dispatch.
@inline function default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA <: AbstractArray}
UA = _unwrapped_array_type(TA)
return UA === TA ? default_driver(TA) : default_driver(Alg, UA)
end

# defaults
default_driver(::Type{TA}) where {TA <: AbstractArray} = Native() # default fallback
default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK()

# wrapper types
@inline default_driver(::Type{Alg}, ::Type{<:SubArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
@inline default_driver(::Type{Alg}, ::Type{<:Base.ReshapedArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
# wrapper types (1-arg form, reached via the generic 2-arg fallback)
@inline default_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_driver(A)
@inline default_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_driver(A)

# Internal helper: strip supported wrapper types to the underlying storage
# array type. Add a new method here when introducing additional wrappers.
@inline _unwrapped_array_type(::Type{TA}) where {TA <: AbstractArray} = TA
@inline _unwrapped_array_type(::Type{<:SubArray{T, N, A}}) where {T, N, A} = _unwrapped_array_type(A)
@inline _unwrapped_array_type(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = _unwrapped_array_type(A)

# Truncation strategy
# -------------------
"""
Expand All @@ -248,6 +262,15 @@ See also [`truncate`](@ref)
"""
abstract type TruncationStrategy end

"""
abstract type SketchingStrategy <: AbstractAlgorithm end

Supertype to denote different sketching strategies, used both as standalone algorithms for
[`left_sketch!`](@ref) and [`right_sketch!`](@ref) and as the `sketch` field of a
[`SketchedAlgorithm`](@ref) for self-truncating SVD.
"""
abstract type SketchingStrategy <: AbstractAlgorithm end

@doc """
MatrixAlgebraKit.select_truncation(trunc)

Expand Down Expand Up @@ -283,7 +306,26 @@ function select_null_truncation(trunc)
elseif trunc isa TruncationStrategy
return trunc
else
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
throw(ArgumentError("Unknown truncation strategy: $trunc"))
end
end

@doc """
MatrixAlgebraKit.select_sketching(A, sketch)

Construct a [`SketchingStrategy`](@ref) for `A` from the given `NamedTuple` of keywords or input strategy `sketch`.
""" select_sketching

@inline select_sketching(A, sketch) = select_sketching(typeof(A), sketch)
@inline function select_sketching(::Type{A}, sketch) where {A}
if isnothing(sketch)
return nothing
elseif sketch isa SketchingStrategy
return sketch
elseif sketch isa NamedTuple
return select_algorithm(left_sketch!, A; sketch...)
else
throw(ArgumentError("Unknown sketching strategy: $sketch"))
end
end

Expand Down Expand Up @@ -317,17 +359,47 @@ See also [`findtruncated`](@ref) and [`findtruncated_svd`](@ref) for determining
function truncate end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationStrategy)

Generic wrapper type for algorithms that consist of first using `alg`, followed by a
truncation through `trunc`.
"""
struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm
struct TruncatedAlgorithm{A <: AbstractAlgorithm, T <: TruncationStrategy} <: AbstractAlgorithm
alg::A
trunc::T
end

"""
SketchedAlgorithm(;
alg::AbstractAlgorithm, sketch::SketchingStrategy,
trunc::TruncationStrategy, driver::Driver = DefaultDriver()
)

Generic wrapper type for self-truncating algorithms that produce an approximate low-rank
factorization by first applying a sketching operation specified by `sketch`, then computing
a small dense decomposition of the projected matrix using `alg`. The `driver` selects the
backend implementing the sketched factorization (e.g. `Native()` for the generic
sketch-then-decompose pipeline, `CUSOLVER()` for the fused `gesvdr` kernel).
"""
@kwdef struct SketchedAlgorithm{
A <: AbstractAlgorithm, S <: SketchingStrategy,
T <: TruncationStrategy, D <: Driver,
} <: AbstractAlgorithm
alg::A = DefaultAlgorithm()
sketch::S
trunc::T = notrunc()
driver::D = DefaultDriver()
end

# utility conversion constructor
TruncatedAlgorithm(alg::SketchedAlgorithm) = TruncatedAlgorithm(alg.alg, alg.trunc)

does_truncate(::TruncatedAlgorithm) = true
does_truncate(::SketchedAlgorithm) = true

truncated_algorithm(alg::AbstractAlgorithm, trunc::TruncationStrategy, sketch = nothing) =
isnothing(sketch) ? TruncatedAlgorithm(alg, trunc) : SketchedAlgorithm(; alg, sketch, trunc)


# Utility macros
# --------------
Expand Down Expand Up @@ -535,7 +607,7 @@ macro check_size(x, sz, size = :size)
szx = $size($x)
$err = $msgstart * string(szx) * " instead of expected value " *
string($sz)
szx == $sz || throw(DimensionMismatch($err))
(szx == $sz) || throw(DimensionMismatch($err))
end
)
end
Expand Down
6 changes: 3 additions & 3 deletions src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ function right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm})
return Nᴴ
end

# randomized algorithms don't currently work for smallest values:
left_null!(A, N, alg::LeftNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) =
# randomized (sketched) algorithms don't currently work for smallest values:
left_null!(A, N, alg::LeftNullViaSVD{<:SketchedAlgorithm}) =
throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet"))
right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) =
right_null!(A, Nᴴ, alg::RightNullViaSVD{<:SketchedAlgorithm}) =
throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet"))
Loading
Loading