Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 19 additions & 37 deletions src/sizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,47 +357,29 @@ function _infer_sizes(
continue
end
op = DEFAULT_MULTIVARIATE_OPERATORS[node.index]
if op == :+ || op == :-
# Broadcasted +/- preserves shape
_copy_size!(sizes, k, children_arr[first(children_indices)])
elseif op == :^
# Broadcasted ^ with scalar exponent preserves base shape
_copy_size!(sizes, k, children_arr[first(children_indices)])
elseif op == :*
# TODO assert compatible sizes and all ndims should be 0 or 2
if op == :+ || op == :- || op == :*
# Broadcasted +/-/* takes the largest child's shape (for
# scalar+matrix that's the matrix; for matrix+matrix we
# currently assume they match and pick the first).
first_matrix = findfirst(children_indices) do i
return !iszero(sizes.ndims[children_arr[i]])
end
if !isnothing(first_matrix)
if sizes.ndims[children_arr[first(children_indices)]] == 0
_add_size!(sizes, k, (1, 1))
continue
else
if sizes.ndims[children_arr[first(children_indices)]] ==
1
nb_cols = 1
else
nb_cols = _size(
sizes,
children_arr[first(children_indices)],
1,
)
end
_add_size!(
sizes,
k,
(
_size(
sizes,
children_arr[first(children_indices)],
1,
),
nb_cols,
),
)
continue
end
if isnothing(first_matrix)
_copy_size!(
sizes,
k,
children_arr[first(children_indices)],
)
else
_copy_size!(
sizes,
k,
children_arr[children_indices[first_matrix]],
)
end
elseif op == :^
# Broadcasted ^ with scalar exponent preserves base shape
_copy_size!(sizes, k, children_arr[first(children_indices)])
end
elseif node.type == NODE_CALL_UNIVARIATE
if !(
Expand Down
4 changes: 2 additions & 2 deletions test/ArrayDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,9 @@ function test_objective_broadcasted_product()
evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
MOI.initialize(evaluator, [:Grad])
sizes = evaluator.backend.objective.expr.sizes
@test sizes.ndims == [0, 2, 1, 0, 0, 1, 0, 0]
@test sizes.ndims == [0, 1, 1, 0, 0, 1, 0, 0]
@test sizes.size_offset == [0, 2, 1, 0, 0, 0, 0, 0]
@test sizes.size == [2, 2, 2, 1]
@test sizes.size == [2, 2, 2]
@test sizes.storage_offset == [0, 1, 3, 5, 6, 7, 9, 10, 11]
x1 = 1.0
x2 = 2.0
Expand Down
89 changes: 89 additions & 0 deletions test/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,95 @@ function test_moi_function()
return
end

# Construct a gradient-enabled evaluator for `expr`, using `vars` to fix the
# variable ordering, then evaluate it at the point `x`.
#
# Returns `(sizes, objective_value, gradient)`, where `sizes` is the shape
# information inferred for the objective's expression tree.
function _build_and_eval(expr, vars, x)
    diff_mode = ArrayDiff.Mode()
    nlp_model = ArrayDiff.model(diff_mode)
    MOI.Nonlinear.set_objective(nlp_model, JuMP.moi_function(expr))
    evaluator =
        MOI.Nonlinear.Evaluator(nlp_model, diff_mode, JuMP.index.(vars))
    MOI.initialize(evaluator, [:Grad])
    inferred_sizes = evaluator.backend.objective.expr.sizes
    objective_value = MOI.eval_objective(evaluator, x)
    gradient = zero(x)
    MOI.eval_objective_gradient(evaluator, gradient, x)
    return inferred_sizes, objective_value, gradient
end

# Regression test: the previous size-inference code special-cased broadcasted
# `+/-/*` with a `nb_cols` formula that only happened to be right for square
# matrices. A 2x3 operand exposes the bug: the old code reported the result
# shape as (2, 2) instead of (2, 3), and `eval_objective` read past the tape.
function test_broadcast_nonsquare_matrix()
    model = Model()
    @variable(model, W[1:2, 1:3], container = ArrayDiff.ArrayOfVariables)
    Y = [10.0 20.0 30.0; 40.0 50.0 60.0]
    x = Float64.(collect(1:6))
    W_val = reshape(x, 2, 3)
    cases = [
        (:+, LinearAlgebra.norm(W .+ Y), W_val .+ Y),
        (:-, LinearAlgebra.norm(W .- Y), W_val .- Y),
        (:*, LinearAlgebra.norm(W .* W), W_val .* W_val),
    ]
    @testset "$(op)" for (op, expr, ref_mat) in cases
        sizes, val, g = _build_and_eval(expr, JuMP.all_variables(model), x)
        # Expression tree: a scalar `norm` on top, the broadcasted op below it
        # (a 2x3 matrix), and two 2x3 leaves — 4 nodes total, three of which
        # are ndims=2 with size (2, 3). The old bug reported (2, 2) for the
        # broadcast node.
        @test sizes.ndims == [0, 2, 2, 2]
        @test sizes.size == [2, 3, 2, 3, 2, 3]
        @test sizes.size_offset == [0, 4, 2, 0]
        @test sizes.storage_offset == [0, 1, 7, 13, 19]
        ref_norm = LinearAlgebra.norm(ref_mat)
        @test val ≈ ref_norm
        expected_gradient = if op == :*
            # d(norm(W .* W))/dW = 2 .* W .^ 3 / norm(W .* W)
            vec(2 .* W_val .^ 3) ./ ref_norm
        else
            # For `.+` / `.-` against a constant matrix, the broadcast result
            # itself is `ref_mat`, so d(norm(ref_mat))/dW = ref_mat / norm.
            vec(ref_mat) ./ ref_norm
        end
        @test g ≈ expected_gradient
    end
    return
end

# The fix replaced the old `(1, 1)` stub for `scalar .op matrix` with a copy of
# the matrix child's full shape. Evaluation and reverse-mode for these
# mixed-rank broadcasts aren't implemented yet (and are out of scope for the
# fix), so this test asserts only the inferred shapes — which is exactly where
# the previous code was wrong.
function test_broadcast_scalar_matrix_size_inference()
    model = Model()
    @variable(model, W[1:2, 1:3], container = ArrayDiff.ArrayOfVariables)
    mode = ArrayDiff.Mode()
    cases = [
        ("scalar .* M", LinearAlgebra.norm(2.5 .* W)),
        ("M .* scalar", LinearAlgebra.norm(W .* 2.5)),
        ("scalar .+ M", LinearAlgebra.norm(2.5 .+ W)),
        ("M .+ scalar", LinearAlgebra.norm(W .+ 2.5)),
        ("scalar .- M", LinearAlgebra.norm(2.5 .- W)),
        ("M .- scalar", LinearAlgebra.norm(W .- 2.5)),
    ]
    @testset "$(name)" for (name, expr) in cases
        nlp_model = ArrayDiff.model(mode)
        MOI.Nonlinear.set_objective(nlp_model, JuMP.moi_function(expr))
        evaluator = MOI.Nonlinear.Evaluator(
            nlp_model,
            mode,
            JuMP.index.(JuMP.all_variables(model)),
        )
        MOI.initialize(evaluator, [:Grad])
        sizes = evaluator.backend.objective.expr.sizes
        # The broadcast node sits at index 2; it must inherit the matrix
        # child's (2, 3) shape rather than the old `(1, 1)` stub.
        @test sizes.ndims[2] == 2
        offset = sizes.size_offset[2]
        @test sizes.size[offset+1] == 2
        @test sizes.size[offset+2] == 3
        # The scalar leaf among the children must remain ndims=0.
        @test 0 in sizes.ndims[3:4]
    end
    return
end

end # module

TestJuMP.runtests()
Loading