22 changes: 18 additions & 4 deletions src/parse_moi.jl
@@ -147,13 +147,24 @@ function _parse_moi_stack!(
     parent_index::Int,
 )
     m, n = x.size
-    # Build vcat(row(v11, v12, ...), row(v21, v22, ...), ...)
+    # Build vcat(row(v11, v12, ...), row(v21, v22, ...), ...).
+    #
+    # The outer loop is `1:m` (forward order), NOT `m:-1:1`. The `:row` nodes
+    # we push end up at consecutive positions in `expr.nodes`, and `:vcat`
+    # later reads its children in tape-index order (CSC `nzrange`) — so the
+    # row with the smallest tape index becomes row 1 of the output matrix.
+    # If the outer loop ran in reverse, `row_m` would land at the smallest
+    # tape index and `:vcat` would silently place it as row 1, producing a
+    # row-flipped matrix on the tape (a latent bug, fixed here).
+    #
+    # The inner loop stays `n:-1:1` because the items go on the stack and pop
+    # in LIFO order — pushing in reverse j order gives forward j-order on
+    # pop, which matches the column-major layout below.
     vcat_id = data.operators.multivariate_operator_to_id[:vcat]
     row_id = data.operators.multivariate_operator_to_id[:row]
     push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, vcat_id, parent_index))
     vcat_idx = length(expr.nodes)
-    # Push rows in reverse order for stack processing
-    for i in m:-1:1
+    for i in 1:m
         push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, row_id, vcat_idx))
         row_idx = length(expr.nodes)
         for j in n:-1:1
@@ -192,11 +203,14 @@ function _parse_moi_stack!(
     parent_index::Int,
 )
     m, n = size(x)
+    # See the `ArrayOfContiguousVariables{2}` overload for the rationale on
+    # the `1:m` outer loop (the previous `m:-1:1` produced a row-flipped
+    # matrix on the tape).
     vcat_id = data.operators.multivariate_operator_to_id[:vcat]
     row_id = data.operators.multivariate_operator_to_id[:row]
     push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, vcat_id, parent_index))
     vcat_idx = length(expr.nodes)
-    for i in m:-1:1
+    for i in 1:m
         push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, row_id, vcat_idx))
         row_idx = length(expr.nodes)
         for j in n:-1:1
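Aside (not part of the diff): a minimal, self-contained Julia sketch of the two orderings described by the new comment in the hunk above. Plain integer vectors stand in for the real `expr.nodes` tape and the evaluation stack; nothing here is the package's actual code.

```julia
# Illustrative only: vectors standing in for the tape and the stack.
m, n = 3, 2

# Tape-index order: a consumer that reads children front-to-back sees the
# row pushed first as row 1, so the outer loop must push i = 1, 2, ..., m.
tape = Int[]
for i in 1:m
    push!(tape, i)
end
@assert tape == [1, 2, 3]

# Stack (LIFO) order: pushing columns as n:-1:1 makes them pop as 1, ..., n.
stack = Int[]
for j in n:-1:1
    push!(stack, j)
end
popped = [pop!(stack) for _ in 1:n]
@assert popped == [1, 2]
```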
54 changes: 41 additions & 13 deletions test/JuMP.jl
@@ -241,29 +241,57 @@ function _test_neural(
     W1_val = [0.3 -0.2; 0.1 0.4]
     W2_val = [-0.1 0.5; 0.2 -0.3]
     obj, g = _eval(model, loss, [vec(W1_val); vec(W2_val)])
-    obj_val = 0.8516435891643307
+    # Reference computed from the same hand-written forward/reverse formulas
+    # as `perf/cuda_vs_pytorch.jl::forward_pass`/`reverse_diff`, adapted to
+    # this test's loss `sum((Y - target).^2)` (no `/ n` scaling, full gradient
+    # over both `W1` and `W2`). `_eval` evaluates the objective at `xstart`
+    # and the gradient at `x = [1, ..., 8]`, so we need the references at the
+    # corresponding inputs.
+    X_const = [1.0 0.5; 0.3 0.8]
+    target_const = [0.5 0.2; 0.1 0.7]
+    obj_val = _ref_objective(W1_val, W2_val, X_const, target_const)
     if with_norm
         obj_val = sqrt(obj_val)
     end
     @test obj ≈ obj_val
-    grad = [
-        12.3913945850742
-        0.6880048864793
-        9.4322503589489
-        0.5223651220724
-        46.2269560438734
-        53.9729454980064
-        45.7401048264386
-        53.4195902684781
-    ]
+    W1_at_grad = reshape([1.0, 2.0, 3.0, 4.0], 2, 2)
+    W2_at_grad = reshape([5.0, 6.0, 7.0, 8.0], 2, 2)
+    grad_sumsq = _ref_gradient(W1_at_grad, W2_at_grad, X_const, target_const)
     if with_norm
-        @test g ≈ grad * 0.019879429552408144
+        # `d/dx ‖E‖₂ = (1/(2‖E‖₂)) · d/dx ‖E‖₂² = grad_sumsq / (2 sqrt(sumsq))`,
+        # taken at the gradient evaluation point.
+        norm_at_grad =
+            sqrt(_ref_objective(W1_at_grad, W2_at_grad, X_const, target_const))
+        @test g ≈ grad_sumsq ./ (2 * norm_at_grad)
     else
-        @test g ≈ grad
+        @test g ≈ grad_sumsq
     end
     return
 end
+
+# Hand-written forward + reverse for the 2-layer MLP `loss = sum((W2 *
+# tanh.(W1 * X) - target).^2)`. Same shape as `perf/cuda_vs_pytorch.jl`'s
+# `forward_pass` / `reverse_diff` but adapted to this test (no `/ n` scaling
+# and gradient over both `W1` and `W2`). Returned gradient is flattened with
+# the JuMP variable convention `[vec(grad_W1); vec(grad_W2)]`.
+function _ref_forward(W1, W2, X, target)
+    y_1 = tanh.(W1 * X)
+    J_1 = 1 .- y_1 .^ 2
+    J_2 = 2 .* (W2 * y_1 .- target)
+    return y_1, J_1, J_2
+end
+
+function _ref_objective(W1, W2, X, target)
+    return sum((W2 * tanh.(W1 * X) .- target) .^ 2)
+end
+
+function _ref_gradient(W1, W2, X, target)
+    y_1, J_1, J_2 = _ref_forward(W1, W2, X, target)
+    grad_W1 = (J_1 .* (W2' * J_2)) * X'
+    grad_W2 = J_2 * y_1'
+    return [vec(grad_W1); vec(grad_W2)]
+end
 
 function test_neural()
     bin = [false, true]
     @testset "$(with_norm ? "norm" : "sum")" for with_norm in bin
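Aside (not part of the PR): the reference formulas above can be sanity-checked against central finite differences of `_ref_objective`, including the `1 / (2‖E‖₂)` chain-rule scaling used in the `with_norm` branch. The sketch below assumes the `_ref_objective` and `_ref_gradient` definitions from `test/JuMP.jl` are in scope; `_check_ref_gradient` is a hypothetical helper name, not part of the test suite.

```julia
# Hypothetical sanity check, not part of the PR. Assumes _ref_objective and
# _ref_gradient from test/JuMP.jl are already defined in this scope.
function _check_ref_gradient()
    X = [1.0 0.5; 0.3 0.8]              # same constants as the test
    target = [0.5 0.2; 0.1 0.7]
    W1 = reshape([1.0, 2.0, 3.0, 4.0], 2, 2)
    W2 = reshape([5.0, 6.0, 7.0, 8.0], 2, 2)
    x = [vec(W1); vec(W2)]
    # Objective as a function of the flattened weights, matching the
    # [vec(W1); vec(W2)] layout returned by _ref_gradient.
    f(z) = _ref_objective(reshape(z[1:4], 2, 2), reshape(z[5:8], 2, 2), X, target)
    h = 1e-6
    fd = map(1:8) do k
        e = zeros(8)
        e[k] = h
        (f(x + e) - f(x - e)) / (2h)    # central difference
    end
    g = _ref_gradient(W1, W2, X, target)
    @assert isapprox(g, fd; rtol = 1e-5)
    # Chain rule used by the with_norm branch: d‖E‖₂/dx = (d‖E‖₂²/dx) / (2‖E‖₂).
    fd_norm = map(1:8) do k
        e = zeros(8)
        e[k] = h
        (sqrt(f(x + e)) - sqrt(f(x - e))) / (2h)
    end
    @assert isapprox(g ./ (2 * sqrt(f(x))), fd_norm; rtol = 1e-5)
    return true
end
```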