From 00d105497ac7b5f21bcd9012e620239ea19aee60 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 28 Apr 2026 11:37:23 +0200 Subject: [PATCH 1/3] Mooncake fwd mode support --- .../TensorOperationsEnzymeExt.jl | 119 ++++++++++++++++++ .../TensorOperationsMooncakeExt.jl | 119 +++++++++++++++++- test/enzyme.jl | 87 +++++++++++++ test/mooncake.jl | 31 +++-- 4 files changed, 336 insertions(+), 20 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 9b40eb74..3dc23e86 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -14,6 +14,7 @@ using Enzyme.EnzymeCore: EnzymeRules @inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true +@inline EnzymeRules.inactive_type(v::Type{<:IndexTuple}) = true function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, @@ -126,6 +127,46 @@ function EnzymeRules.reverse( return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)... 
end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(TensorOperations.tensorcontract!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Annotation{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + B_dB::Annotation{<:AbstractArray{TB}}, + pB_dpB::Annotation{<:Index2Tuple}, + conjB_dconjB::Const{Bool}, + pAB_dpAB::Annotation{<:Index2Tuple}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + α = α_dα.val + β = β_dβ.val + pA, pB, pAB, conjA, conjB = getfield.((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val) + + if !isa(C_dC, Const) + scale!(C_dC.dval, β) + if !isa(β_dβ, Const) + @. C_dC.dval += β_dβ.dval * C_dC.val + end + if !isa(α_dα, Const) + tensorcontract!(C_dC.dval, A_dA.val, pA, conjA, B_dB.val, pB, conjB, pAB, α_dα.dval, One(), ba...) + end + if !isa(A_dA, Const) + tensorcontract!(C_dC.dval, A_dA.dval, pA, conjA, B_dB.val, pB, conjB, pAB, α, One(), ba...) + end + if !isa(B_dB, Const) + tensorcontract!(C_dC.dval, A_dA.val, pA, conjA, B_dB.dval, pB, conjB, pAB, α, One(), ba...):Zero() + end + end + TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA, conjA, B_dB.val, pB, conjB, pAB, α, β, ba...) + return C_dC +end + function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, ::Annotation{typeof(tensoradd!)}, @@ -198,6 +239,45 @@ function EnzymeRules.reverse( return nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)... 
end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + ::Annotation{typeof(tensoradd!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Annotation{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + pA = pA_dpA.val + conjA = conjA_dconjA.val + α = α_dα.val + β = β_dβ.val + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + + # D = α * A + β *C + + # dD = dα * A + α * dA + β dC + dβ * C + + # dC′ = β dC + dβ * C + if !isa(C_dC, Const) + scale!(C_dC.dval, β) + if !isa(β_dβ, Const) + @. C_dC.dval += β_dβ.dval * C_dC.val + end + if !isa(A_dA, Const) + TensorOperations.tensoradd!(C_dC.dval, A_dA.dval, pA, conjA, α, One(), ba...) + end + if !isa(α_dα, Const) + TensorOperations.tensoradd!(C_dC.dval, A_dA.val, pA, conjA, α_dα.dval, One(), ba...) + end + end + TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA, conjA, α, β, ba...) + return C_dC +end + function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, ::Annotation{typeof(tensortrace!)}, @@ -273,4 +353,43 @@ function EnzymeRules.reverse( return nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)... 
end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + ::Annotation{typeof(tensortrace!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + p_dp::Annotation{<:Index2Tuple}, + q_dq::Annotation{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + p = p_dp.val + q = q_dq.val + conjA = conjA_dconjA.val + α = α_dα.val + β = β_dβ.val + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC + # dC1 = dβ * C + β * dC + if !isa(C_dC, Const) + scale!(C_dC.dval, β) + if !isa(β_dβ, Const) + @. C_dC, dval += β_dβ.dval * C_dC.val + end + if !isa(α_dα, Const) + TensorOperations.tensortrace!(C_dC.dval, A_dA.val, p, q, conjA, α_dα.dval, One(), ba...) + end + if !isa(A_dA, Const) + TensorOperations.tensortrace!(C_dC.dval, A_dA.dval, p, q, conjA, α, One(), ba...) + end + end + # D = α * tr(A) + β * C + TensorOperations.tensortrace!(C_dC.val, A_dA.val, p, q, conjA, α, β, ba...) + return C_dC +end + end diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 8e302ca5..33f03ac2 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -7,7 +7,7 @@ using TensorOperations using Mooncake, Mooncake.CRC using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace! 
-using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent +using Mooncake: ReverseMode, DefaultCtx, Dual, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent using VectorInterface, TupleTools Mooncake.tangent_type(::Type{Index2Tuple}) = Mooncake.NoTangent @@ -29,7 +29,7 @@ Mooncake.tangent_type(::Type{ManualAllocator}) = Mooncake.NoTangent Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(TensorOperations.tensorfree!), Any} Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(TensorOperations.tensoralloc), Any, Any, Any, Any} -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensorcontract!), AbstractArray, AbstractArray, Index2Tuple, Bool, AbstractArray, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}} +Mooncake.@is_primitive DefaultCtx Tuple{typeof(tensorcontract!), AbstractArray, AbstractArray, Index2Tuple, Bool, AbstractArray, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}} function Mooncake.rrule!!( ::CoDual{typeof(tensorcontract!)}, C_dC::CoDual{<:AbstractArray{TC}}, @@ -67,7 +67,47 @@ function Mooncake.rrule!!( return C_dC, contract_pb end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensoradd!), AbstractArray, AbstractArray, Index2Tuple, Bool, Number, Number, Vararg{Any}} +function Mooncake.frule!!( + ::Dual{typeof(tensorcontract!)}, + C_dC::Dual{<:AbstractArray{TC}}, + A_dA::Dual{<:AbstractArray{TA}}, + pA_dpA::Dual{<:Index2Tuple}, + conjA_dconjA::Dual{Bool}, + B_dB::Dual{<:AbstractArray{TB}}, + pB_dpB::Dual{<:Index2Tuple}, + conjB_dconjB::Dual{Bool}, + pAB_dpAB::Dual{<:Index2Tuple}, + α_dα::Dual{Tα}, + β_dβ::Dual{Tβ}, + ba_dba::Dual..., + ) where {Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + C, dC = arrayify(C_dC) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + pA = primal(pA_dpA) + pB = primal(pB_dpB) + pAB = primal(pAB_dpAB) + conjA = primal(conjA_dconjA) + conjB = primal(conjB_dconjB) + α, dα = 
Mooncake.extract(α_dα) + β, dβ = Mooncake.extract(β_dβ) + ba = primal.(ba_dba) + + # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α + scale!(dC, β) + if !isa(dβ, Mooncake.NoTangent) + @. dC += dβ * C + end + if !isa(dα, Mooncake.NoTangent) + tensorcontract!(dC, A, pA, conjA, B, pB, conjB, pAB, dα, One(), ba...) + end + tensorcontract!(dC, dA, pA, conjA, B, pB, conjB, pAB, α, One(), ba...) + tensorcontract!(dC, A, pA, conjA, dB, pB, conjB, pAB, α, One(), ba...) + TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) + return C_dC +end + +Mooncake.@is_primitive DefaultCtx Tuple{typeof(tensoradd!), AbstractArray, AbstractArray, Index2Tuple, Bool, Number, Number, Vararg{Any}} function Mooncake.rrule!!( ::CoDual{typeof(tensoradd!)}, C_dC::CoDual{<:AbstractArray{TC}}, @@ -97,7 +137,43 @@ function Mooncake.rrule!!( return C_dC, add_pb end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensortrace!), AbstractArray, AbstractArray, Index2Tuple, Index2Tuple, Bool, Number, Number, Vararg{Any}} +function Mooncake.frule!!( + ::Dual{typeof(tensoradd!)}, + C_dC::Dual{<:AbstractArray{TC}}, + A_dA::Dual{<:AbstractArray{TA}}, + pA_dpA::Dual{<:Index2Tuple}, + conjA_dconjA::Dual{Bool}, + α_dα::Dual{Tα}, + β_dβ::Dual{Tβ}, + ba_dba::Dual..., + ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + C, dC = arrayify(C_dC) + A, dA = arrayify(A_dA) + pA = primal(pA_dpA) + conjA = primal(conjA_dconjA) + α = primal(α_dα) + dα = tangent(α_dα) + β = primal(β_dβ) + dβ = tangent(β_dβ) + ba = primal.(ba_dba) + # D = α * A + β *C + + # dD = dα * A + α * dA + β dC + dβ * C + + # dC′ = β dC + dβ * C + scale!(dC, β) + if !isa(dβ, Mooncake.NoTangent) + @. dC += dβ * C + end + TensorOperations.tensoradd!(dC, dA, pA, conjA, α, One(), ba...) + if !isa(dα, Mooncake.NoTangent) + TensorOperations.tensoradd!(dC, A, pA, conjA, dα, One(), ba...) + end + TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...) 
+ return C_dC +end + +Mooncake.@is_primitive DefaultCtx Tuple{typeof(tensortrace!), AbstractArray, AbstractArray, Index2Tuple, Index2Tuple, Bool, Number, Number, Vararg{Any}} function Mooncake.rrule!!( ::CoDual{typeof(tensortrace!)}, C_dC::CoDual{<:AbstractArray{TC}}, @@ -129,4 +205,39 @@ function Mooncake.rrule!!( return C_dC, trace_pb end +function Mooncake.frule!!( + ::Dual{typeof(tensortrace!)}, + C_dC::Dual{<:AbstractArray{TC}}, + A_dA::Dual{<:AbstractArray{TA}}, + p_dp::Dual{<:Index2Tuple}, + q_dq::Dual{<:Index2Tuple}, + conjA_dconjA::Dual{Bool}, + α_dα::Dual{Tα}, + β_dβ::Dual{Tβ}, + ba_dba::Dual..., + ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + C, dC = arrayify(C_dC) + A, dA = arrayify(A_dA) + p = primal(p_dp) + q = primal(q_dq) + conjA = primal(conjA_dconjA) + α = primal(α_dα) + dα = tangent(α_dα) + β = primal(β_dβ) + dβ = tangent(β_dβ) + ba = primal.(ba_dba) + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC + # dC1 = dβ * C + β * dC + scale!(dC, β) + if !isa(dβ, Mooncake.NoTangent) + @. dC += dβ * C + end + if !isa(dα, Mooncake.NoTangent) + TensorOperations.tensortrace!(dC, A, p, q, conjA, dα, One(), ba...) + end + TensorOperations.tensortrace!(dC, dA, p, q, conjA, α, One(), ba...) + TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...) 
+ return C_dC +end + end diff --git a/test/enzyme.jl b/test/enzyme.jl index 5a4659f4..16612c8d 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -57,6 +57,36 @@ is_ci = get(ENV, "CI", "false") == "true" end end end + @testset for (α, β) in αβs + Tαs = if α === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + Tβs = if β === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + for (Tα, Tβ) in zip(Tαs, Tβs) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + if !(T <: Real) && !(α === Zero()) && !(β === Zero()) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (real(α), Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (real(β), Tβ), 
(StridedNative(), Const); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + test_forward(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + end + end + end end end @@ -104,6 +134,34 @@ end end end end + # test zeros only once to avoid wasteful tests + @testset for (α, β) in αβs + Tαs = if α === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + Tβs = if β === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + for (Tα, Tβ) in zip(Tαs, Tβs) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + if !(T <: Real) && !(α === Zero()) && !(β === Zero()) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (real(α), Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_forward(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + end + end + end end end @@ -153,6 +211,34 @@ end end end end + # test zeros only once to avoid wasteful tests + @testset for (α, β) in αβs + Tαs = if α === Zero() + 
(Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + Tβs = if β === Zero() + (Const,) + elseif !is_ci + (Duplicated, Const) + else + (Duplicated,) + end + for (Tα, Tβ) in zip(Tαs, Tβs) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + if !(T <: Real) && !(α === Zero()) && !(β === Zero()) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (real(α), Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_forward(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + end + end + end end end @@ -163,4 +249,5 @@ end C = Array{T, 0}(undef, ()) fill!(C, rand(T)) test_reverse(tensorscalar, Active, (C, Duplicated); atol, rtol) + test_forward(tensorscalar, Duplicated, (C, Duplicated); atol, rtol) end diff --git a/test/mooncake.jl b/test/mooncake.jl index e4bdcc69..342867f0 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -4,7 +4,6 @@ using Test using Mooncake using Random -mode = Mooncake.ReverseMode rng = Random.default_rng() is_primitive = false @@ -25,11 +24,11 @@ is_primitive = false A = rand(T₁, (2, 3, 4, 2, 5)) C = rand(T₂, size.(Ref(A), p[1])) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, 
tensortrace!, C, A, p, q, true, α, β; atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β; atol, rtol, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β, StridedNative(); atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β, StridedBLAS(); atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β, StridedNative(); atol, rtol, is_primitive) end end @@ -48,11 +47,11 @@ end A = rand(T₁, (2, 3, 4, 2, 1)) C = rand(T₂, size.(Ref(A), pA[1])) @testset for α in (Zero(), rand(T)), β in (Zero(), rand(T)) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β; atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β; atol, rtol, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β, StridedNative(); atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β, StridedBLAS(); atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β, StridedNative(); atol, rtol, is_primitive) end end @@ -77,20 +76,20 @@ end C = rand(T, (5, 2, 3, 3)) @testset for α in (Zero(), randn(T)), β in (Zero(), randn(T)) - Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β; atol, rtol, mode, 
is_primitive) - Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, true, pAB, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, true, pAB, α, β; atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, true, pAB, α, β; atol, rtol, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, true, pAB, α, β; atol, rtol, is_primitive) Mooncake.TestUtils.test_rule( rng, tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β, StridedBLAS(); - atol, rtol, mode, is_primitive + atol, rtol, is_primitive ) Mooncake.TestUtils.test_rule( rng, tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β, StridedNative(); - atol, rtol, mode, is_primitive + atol, rtol, is_primitive ) end end @@ -101,5 +100,5 @@ end C = Array{T, 0}(undef, ()) fill!(C, rand(T)) - Mooncake.TestUtils.test_rule(rng, tensorscalar, C; atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensorscalar, C; atol, rtol, is_primitive) end From 0b2ad05053179f009332d3b35b0205ce201dbaa4 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 28 Apr 2026 16:27:53 +0200 Subject: [PATCH 2/3] Typo --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 3dc23e86..b5dcf2b4 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ 
b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -160,7 +160,7 @@ function EnzymeRules.forward( tensorcontract!(C_dC.dval, A_dA.dval, pA, conjA, B_dB.val, pB, conjB, pAB, α, One(), ba...) end if !isa(B_dB, Const) - tensorcontract!(C_dC.dval, A_dA.val, pA, conjA, B_dB.dval, pB, conjB, pAB, α, One(), ba...):Zero() + tensorcontract!(C_dC.dval, A_dA.val, pA, conjA, B_dB.dval, pB, conjB, pAB, α, One(), ba...) end end TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA, conjA, B_dB.val, pB, conjB, pAB, α, β, ba...) From 8ef24cf717771b19b6866d62a32b53a742e8a904 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 29 Apr 2026 11:01:33 +0200 Subject: [PATCH 3/3] Use add! rather than broadcasting Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 6 +++--- .../TensorOperationsMooncakeExt.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index b5dcf2b4..2eabbcd5 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -151,7 +151,7 @@ function EnzymeRules.forward( if !isa(C_dC, Const) scale!(C_dC.dval, β) if !isa(β_dβ, Const) - @. C_dC.dval += β_dβ.dval * C_dC.val + add!(C_dC.dval, C_dC.val, β_dβ.dval) end if !isa(α_dα, Const) tensorcontract!(C_dC.dval, A_dA.val, pA, conjA, B_dB.val, pB, conjB, pAB, α_dα.dval, One(), ba...) @@ -265,7 +265,7 @@ function EnzymeRules.forward( if !isa(C_dC, Const) scale!(C_dC.dval, β) if !isa(β_dβ, Const) - @. C_dC.dval += β_dβ.dval * C_dC.val + add!(C_dC.dval, C_dC.val, β_dβ.dval) end if !isa(A_dA, Const) TensorOperations.tensoradd!(C_dC.dval, A_dA.dval, pA, conjA, α, One(), ba...) @@ -378,7 +378,7 @@ function EnzymeRules.forward( if !isa(C_dC, Const) scale!(C_dC.dval, β) if !isa(β_dβ, Const) - @. 
C_dC, dval += β_dβ.dval * C_dC.val + add!(C_dC.dval, C_dC.val, β_dβ.dval) end if !isa(α_dα, Const) TensorOperations.tensortrace!(C_dC.dval, A_dA.val, p, q, conjA, α_dα.dval, One(), ba...) diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 33f03ac2..57acd534 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -96,7 +96,7 @@ function Mooncake.frule!!( # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α scale!(dC, β) if !isa(dβ, Mooncake.NoTangent) - @. dC += dβ * C + add!(dC, C, dβ) end if !isa(dα, Mooncake.NoTangent) tensorcontract!(dC, A, pA, conjA, B, pB, conjB, pAB, dα, One(), ba...) @@ -163,7 +163,7 @@ function Mooncake.frule!!( # dC′ = β dC + dβ * C scale!(dC, β) if !isa(dβ, Mooncake.NoTangent) - @. dC += dβ * C + add!(dC, C, dβ) end TensorOperations.tensoradd!(dC, dA, pA, conjA, α, One(), ba...) if !isa(dα, Mooncake.NoTangent) @@ -230,7 +230,7 @@ function Mooncake.frule!!( # dC1 = dβ * C + β * dC scale!(dC, β) if !isa(dβ, Mooncake.NoTangent) - @. dC += dβ * C + add!(dC, C, dβ) end if !isa(dα, Mooncake.NoTangent) TensorOperations.tensortrace!(dC, A, p, q, conjA, dα, One(), ba...)