diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 9b40eb74..2eabbcd5 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -14,6 +14,7 @@ using Enzyme.EnzymeCore: EnzymeRules @inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true +@inline EnzymeRules.inactive_type(v::Type{<:IndexTuple}) = true function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, @@ -126,6 +127,46 @@ function EnzymeRules.reverse( return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)... end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(TensorOperations.tensorcontract!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Annotation{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + B_dB::Annotation{<:AbstractArray{TB}}, + pB_dpB::Annotation{<:Index2Tuple}, + conjB_dconjB::Const{Bool}, + pAB_dpAB::Annotation{<:Index2Tuple}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + α = α_dα.val + β = β_dβ.val + pA, pB, pAB, conjA, conjB = getfield.((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val) + + if !isa(C_dC, Const) + scale!(C_dC.dval, β) + if !isa(β_dβ, Const) + add!(C_dC.dval, C_dC.val, β_dβ.dval) + end + if !isa(α_dα, Const) + tensorcontract!(C_dC.dval, A_dA.val, pA, conjA, B_dB.val, pB, conjB, pAB, α_dα.dval, One(), ba...) + end + if !isa(A_dA, Const) + tensorcontract!(C_dC.dval, A_dA.dval, pA, conjA, B_dB.val, pB, conjB, pAB, α, One(), ba...) + end + if !isa(B_dB, Const) + tensorcontract!(C_dC.dval, A_dA.val, pA, conjA, B_dB.dval, pB, conjB, pAB, α, One(), ba...) + end + end + TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA, conjA, B_dB.val, pB, conjB, pAB, α, β, ba...) + return C_dC +end + function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, ::Annotation{typeof(tensoradd!)}, @@ -198,6 +239,45 @@ function EnzymeRules.reverse( return nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)... end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + ::Annotation{typeof(tensoradd!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Annotation{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + pA = pA_dpA.val + conjA = conjA_dconjA.val + α = α_dα.val + β = β_dβ.val + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + + # D = α * A + β *C + + # dD = dα * A + α * dA + β dC + dβ * C + + # dC′ = β dC + dβ * C + if !isa(C_dC, Const) + scale!(C_dC.dval, β) + if !isa(β_dβ, Const) + add!(C_dC.dval, C_dC.val, β_dβ.dval) + end + if !isa(A_dA, Const) + TensorOperations.tensoradd!(C_dC.dval, A_dA.dval, pA, conjA, α, One(), ba...) + end + if !isa(α_dα, Const) + TensorOperations.tensoradd!(C_dC.dval, A_dA.val, pA, conjA, α_dα.dval, One(), ba...) + end + end + TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA, conjA, α, β, ba...) + return C_dC +end + function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, ::Annotation{typeof(tensortrace!)}, @@ -273,4 +353,43 @@ function EnzymeRules.reverse( return nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)... end +function EnzymeRules.forward( + config::EnzymeRules.RevConfigWidth{1}, + ::Annotation{typeof(tensortrace!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + p_dp::Annotation{<:Index2Tuple}, + q_dq::Annotation{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + p = p_dp.val + q = q_dq.val + conjA = conjA_dconjA.val + α = α_dα.val + β = β_dβ.val + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC + # dC1 = dβ * C + β * dC + if !isa(C_dC, Const) + scale!(C_dC.dval, β) + if !isa(β_dβ, Const) + add!(C_dC.dval, C_dC.val, β_dβ.dval) + end + if !isa(α_dα, Const) + TensorOperations.tensortrace!(C_dC.dval, A_dA.val, p, q, conjA, α_dα.dval, One(), ba...) + end + if !isa(A_dA, Const) + TensorOperations.tensortrace!(C_dC.dval, A_dA.dval, p, q, conjA, α, One(), ba...) + end + end + # D = α * tr(A) + β * C + TensorOperations.tensortrace!(C_dC.val, A_dA.val, p, q, conjA, α, β, ba...) + return C_dC +end + end diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 8e302ca5..57acd534 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -7,7 +7,7 @@ using TensorOperations using Mooncake, Mooncake.CRC using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace! -using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent +using Mooncake: ReverseMode, DefaultCtx, Dual, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent using VectorInterface, TupleTools Mooncake.tangent_type(::Type{Index2Tuple}) = Mooncake.NoTangent @@ -29,7 +29,7 @@ Mooncake.tangent_type(::Type{ManualAllocator}) = Mooncake.NoTangent Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(TensorOperations.tensorfree!), Any} Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(TensorOperations.tensoralloc), Any, Any, Any, Any} -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensorcontract!), AbstractArray, AbstractArray, Index2Tuple, Bool, AbstractArray, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}} +Mooncake.@is_primitive DefaultCtx Tuple{typeof(tensorcontract!), AbstractArray, AbstractArray, Index2Tuple, Bool, AbstractArray, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}} function Mooncake.rrule!!( ::CoDual{typeof(tensorcontract!)}, C_dC::CoDual{<:AbstractArray{TC}}, @@ -67,7 +67,47 @@ function Mooncake.rrule!!( return C_dC, contract_pb end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensoradd!), AbstractArray, AbstractArray, Index2Tuple, Bool, Number, Number, Vararg{Any}} +function Mooncake.frule!!( + ::Dual{typeof(tensorcontract!)}, + C_dC::Dual{<:AbstractArray{TC}}, + A_dA::Dual{<:AbstractArray{TA}}, + pA_dpA::Dual{<:Index2Tuple}, + conjA_dconjA::Dual{Bool}, + B_dB::Dual{<:AbstractArray{TB}}, + pB_dpB::Dual{<:Index2Tuple}, + conjB_dconjB::Dual{Bool}, + pAB_dpAB::Dual{<:Index2Tuple}, + α_dα::Dual{Tα}, + β_dβ::Dual{Tβ}, + ba_dba::Dual..., + ) where {Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + C, dC = arrayify(C_dC) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + pA = primal(pA_dpA) + pB = primal(pB_dpB) + pAB = primal(pAB_dpAB) + conjA = primal(conjA_dconjA) + conjB = primal(conjB_dconjB) + α, dα = Mooncake.extract(α_dα) + β, dβ = Mooncake.extract(β_dβ) + ba = primal.(ba_dba) + + # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α + scale!(dC, β) + if !isa(dβ, Mooncake.NoTangent) + add!(dC, C, dβ) + end + if !isa(dα, Mooncake.NoTangent) + tensorcontract!(dC, A, pA, conjA, B, pB, conjB, pAB, dα, One(), ba...) + end + tensorcontract!(dC, dA, pA, conjA, B, pB, conjB, pAB, α, One(), ba...) + tensorcontract!(dC, A, pA, conjA, dB, pB, conjB, pAB, α, One(), ba...) + TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) + return C_dC +end + +Mooncake.@is_primitive DefaultCtx Tuple{typeof(tensoradd!), AbstractArray, AbstractArray, Index2Tuple, Bool, Number, Number, Vararg{Any}} function Mooncake.rrule!!( ::CoDual{typeof(tensoradd!)}, C_dC::CoDual{<:AbstractArray{TC}}, @@ -97,7 +137,43 @@ function Mooncake.rrule!!( return C_dC, add_pb end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensortrace!), AbstractArray, AbstractArray, Index2Tuple, Index2Tuple, Bool, Number, Number, Vararg{Any}} +function Mooncake.frule!!( + ::Dual{typeof(tensoradd!)}, + C_dC::Dual{<:AbstractArray{TC}}, + A_dA::Dual{<:AbstractArray{TA}}, + pA_dpA::Dual{<:Index2Tuple}, + conjA_dconjA::Dual{Bool}, + α_dα::Dual{Tα}, + β_dβ::Dual{Tβ}, + ba_dba::Dual..., + ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + C, dC = arrayify(C_dC) + A, dA = arrayify(A_dA) + pA = primal(pA_dpA) + conjA = primal(conjA_dconjA) + α = primal(α_dα) + dα = tangent(α_dα) + β = primal(β_dβ) + dβ = tangent(β_dβ) + ba = primal.(ba_dba) + # D = α * A + β *C + + # dD = dα * A + α * dA + β dC + dβ * C + + # dC′ = β dC + dβ * C + scale!(dC, β) + if !isa(dβ, Mooncake.NoTangent) + add!(dC, C, dβ) + end + TensorOperations.tensoradd!(dC, dA, pA, conjA, α, One(), ba...) + if !isa(dα, Mooncake.NoTangent) + TensorOperations.tensoradd!(dC, A, pA, conjA, dα, One(), ba...) + end + TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...) + return C_dC +end + +Mooncake.@is_primitive DefaultCtx Tuple{typeof(tensortrace!), AbstractArray, AbstractArray, Index2Tuple, Index2Tuple, Bool, Number, Number, Vararg{Any}} function Mooncake.rrule!!( ::CoDual{typeof(tensortrace!)}, C_dC::CoDual{<:AbstractArray{TC}}, @@ -129,4 +205,39 @@ function Mooncake.rrule!!( return C_dC, trace_pb end +function Mooncake.frule!!( + ::Dual{typeof(tensortrace!)}, + C_dC::Dual{<:AbstractArray{TC}}, + A_dA::Dual{<:AbstractArray{TA}}, + p_dp::Dual{<:Index2Tuple}, + q_dq::Dual{<:Index2Tuple}, + conjA_dconjA::Dual{Bool}, + α_dα::Dual{Tα}, + β_dβ::Dual{Tβ}, + ba_dba::Dual..., + ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + C, dC = arrayify(C_dC) + A, dA = arrayify(A_dA) + p = primal(p_dp) + q = primal(q_dq) + conjA = primal(conjA_dconjA) + α = primal(α_dα) + dα = tangent(α_dα) + β = primal(β_dβ) + dβ = tangent(β_dβ) + ba = primal.(ba_dba) + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC + # dC1 = dβ * C + β * dC + scale!(dC, β) + if !isa(dβ, Mooncake.NoTangent) + add!(dC, C, dβ) + end + if !isa(dα, Mooncake.NoTangent) + TensorOperations.tensortrace!(dC, A, p, q, conjA, dα, One(), ba...) + end + TensorOperations.tensortrace!(dC, dA, p, q, conjA, α, One(), ba...) + TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...) + return C_dC +end + end diff --git a/test/enzyme.jl b/test/enzyme.jl index 5a4659f4..16612c8d 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -57,6 +57,36 @@ is_ci = get(ENV, "CI", "false") == "true" end end end + @testset for (α, β) in αβs + Tαs = if α === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + Tβs = if β === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + for (Tα, Tβ) in zip(Tαs, Tβs) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + if !(T <: Real) && !(α === Zero()) && !(β === Zero()) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (real(α), Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + end + end + end end end @@ -104,6 +134,34 @@ end end end end + # test zeros only once to avoid wasteful tests + @testset for (α, β) in αβs + Tαs = if α === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + Tβs = if β === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + for (Tα, Tβ) in zip(Tαs, Tβs) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + if !(T <: Real) && !(α === Zero()) && !(β === Zero()) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (real(α), Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + end + end + end end end @@ -153,6 +211,34 @@ end end end end + # test zeros only once to avoid wasteful tests + @testset for (α, β) in αβs + Tαs = if α === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + Tβs = if β === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + for (Tα, Tβ) in zip(Tαs, Tβs) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + if !(T <: Real) && !(α === Zero()) && !(β === Zero()) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (real(α), Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + end + end + end end end @@ -163,4 +249,5 @@ end C = Array{T, 0}(undef, ()) fill!(C, rand(T)) test_reverse(tensorscalar, Active, (C, Duplicated); atol, rtol) + test_forward(tensorscalar, Duplicated, (C, Duplicated); atol, rtol) end diff --git a/test/mooncake.jl b/test/mooncake.jl index e4bdcc69..342867f0 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -4,7 +4,6 @@ using Test using Mooncake using Random -mode = Mooncake.ReverseMode rng = Random.default_rng() is_primitive = false @@ -25,11 +24,11 @@ is_primitive = false A = rand(T₁, (2, 3, 4, 2, 5)) C = rand(T₂, size.(Ref(A), p[1])) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β; atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β; atol, rtol, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β, StridedNative(); atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β, StridedBLAS(); atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β, StridedNative(); atol, rtol, is_primitive) end end @@ -48,11 +47,11 @@ end A = rand(T₁, (2, 3, 4, 2, 1)) C = rand(T₂, size.(Ref(A), pA[1])) @testset for α in (Zero(), rand(T)), β in (Zero(), rand(T)) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β; atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β; atol, rtol, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β, StridedNative(); atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β, StridedBLAS(); atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β, StridedNative(); atol, rtol, is_primitive) end end @@ -77,20 +76,20 @@ end C = rand(T, (5, 2, 3, 3)) @testset for α in (Zero(), randn(T)), β in (Zero(), randn(T)) - Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, true, pAB, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, true, pAB, α, β; atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, true, pAB, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, true, pAB, α, β; atol, rtol, is_primitive) Mooncake.TestUtils.test_rule( rng, tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β, StridedBLAS(); - atol, rtol, mode, is_primitive + atol, rtol, is_primitive ) Mooncake.TestUtils.test_rule( rng, tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β, StridedNative(); - atol, rtol, mode, is_primitive + atol, rtol, is_primitive ) end end @@ -101,5 +100,5 @@ end C = Array{T, 0}(undef, ()) fill!(C, rand(T)) - Mooncake.TestUtils.test_rule(rng, tensorscalar, C; atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorscalar, C; atol, rtol, is_primitive) end