Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
ExtractDiag,
Eye,
TensorVariable,
alloc_diag,
concatenate,
diag,
diagonal,
ones,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
from pytensor.tensor.math import Dot, Prod, _matmul, add, log, outer, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
Expand Down Expand Up @@ -1145,3 +1146,78 @@ def scalar_solve_to_division(fgraph, node):
copy_stack_trace(old_out, new_out)

return [new_out]


def _extract_diagonal(x):
    """Return the diagonal entries when ``x`` is provably diagonal; otherwise ``None``.

    The supported patterns are:
    - ``AllocDiag`` with zero offset
    - elementwise multiplication with an identity matrix

    Parameters
    ----------
    x : TensorVariable
        Candidate (possibly batched) matrix variable.

    Returns
    -------
    TensorVariable or None
        The diagonal entries of ``x``; for an ``(n, m)`` identity factor the
        core diagonal length is ``min(n, m)``. ``None`` when ``x`` is not
        provably diagonal.
    """
    if not x.owner:
        return None

    # Case 1: x was constructed directly from its diagonal entries.
    if isinstance(x.owner.op, AllocDiag) and AllocDiag.is_offset_zero(x.owner):
        return x.owner.inputs[0]

    # Case 2: x is eye(...) * something with exactly one non-eye factor.
    inputs_or_none = _find_diag_from_eye_mul(x)
    if inputs_or_none is None:
        return None

    eye_input, non_eye_inputs = inputs_or_none
    if len(non_eye_inputs) != 1:
        return None

    [non_eye_input] = non_eye_inputs

    # True diagonal length of an (n, m) identity is min(n, m); using just
    # shape[-1] would be wrong for rectangular eye inputs.
    diag_len = pt.minimum(eye_input.shape[-2], eye_input.shape[-1])

    if non_eye_input.type.broadcastable[-2:] == (True, True):
        scalar_input = non_eye_input.squeeze(axis=(-1, -2))
        if scalar_input.ndim == 0:
            return scalar_input
        # For batched scalar * eye, return batched diagonal entries (B, K)
        # with K = min(n, m), not batch scalars (B), so downstream
        # reconstruction produces the correct core diagonal shape.
        return scalar_input[..., None] * pt.ones(
            (diag_len,), dtype=scalar_input.dtype
        )
    if non_eye_input.type.broadcastable[-2:] == (False, False):
        # A full matrix multiplied by eye keeps only its main diagonal,
        # which .diagonal() already truncates to min(n, m).
        return non_eye_input.diagonal(axis1=-1, axis2=-2)

    # Row- or column-vector factor: drop the broadcastable axis, then trim
    # to the identity's diagonal length for rectangular eye inputs.
    squeeze_axis = -2 if non_eye_input.type.broadcastable[-2] else -1
    return non_eye_input.squeeze(axis=squeeze_axis)[..., :diag_len]


@register_canonicalize
@register_stabilize
@node_rewriter([add])
def rewrite_add_diag_to_diag_add(fgraph, node):
    """Rewrite sums of diagonal matrices into one diagonal construction.

    Uses ``diag(A + B) = diag(A) + diag(B)`` in reverse to avoid full matrix adds.

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph being rewritten (unused, required by the rewriter API).
    node : Apply
        An ``add`` node whose inputs may all be provably diagonal.

    Returns
    -------
    list of TensorVariable or None
        The replacement output, or ``None`` when the rewrite does not apply.
    """
    old_out = node.outputs[0]

    if old_out.type.ndim < 2:
        return None

    diagonal_inputs = []
    for inp in node.inputs:
        diagonal_input = _extract_diagonal(inp)
        if diagonal_input is None:
            return None
        diagonal_inputs.append(diagonal_input)

    summed_diag = add(*diagonal_inputs)
    if summed_diag.ndim == 0:
        # Pure scalar diagonal: rebuild as scalar * I using the output's own
        # (possibly rectangular) shape, which is always shape-safe.
        new_out = (
            pt.eye(old_out.shape[-2], old_out.shape[-1], dtype=old_out.dtype)
            * summed_diag
        )
    else:
        # alloc_diag always reconstructs a square matrix from the diagonal
        # length. Only apply it when the output is statically known to be
        # square; otherwise a rectangular eye(n, m) pattern would silently
        # change the output shape from (..., n, m) to a square matrix.
        static_rows, static_cols = old_out.type.shape[-2:]
        if static_rows is None or static_cols is None or static_rows != static_cols:
            return None
        new_out = alloc_diag(summed_diag, axis1=-2, axis2=-1)

    if new_out.dtype != old_out.dtype:
        new_out = pt.cast(new_out, old_out.dtype)

    copy_stack_trace(old_out, new_out)
    return [new_out]
97 changes: 96 additions & 1 deletion tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph, ancestors
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.scalar.basic import Add
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import dot, matmul
from pytensor.tensor.nlinalg import (
SVD,
Expand Down Expand Up @@ -1128,3 +1129,97 @@ def solve_op_in_graph(graph):
np.testing.assert_allclose(
f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
)


def test_add_diag_rewrite_from_diag_inputs():
    """diag(a) + diag(b) should compile without a dense matrix addition."""
    vec_a = pt.vector("x")
    vec_b = pt.vector("y")
    summed = pt.diag(vec_a) + pt.diag(vec_b)

    rng = np.random.default_rng(sum(map(ord, "test_add_diag_rewrite_from_diag_inputs")))
    vec_a_val = rng.normal(size=(7,)).astype(config.floatX)
    vec_b_val = rng.normal(size=(7,)).astype(config.floatX)
    reference = np.diag(vec_a_val + vec_b_val)

    _assert_rewrite_add_diag_no_dense_add(
        x=vec_a,
        y=vec_b,
        z=summed,
        out_ndim=2,
        x_test=vec_a_val,
        y_test=vec_b_val,
        expected=reference,
    )


def _assert_rewrite_add_diag_no_dense_add(x, y, z, out_ndim, x_test, y_test, expected):
    """Compile ``z`` and verify the rewrite removed every dense matrix add.

    A "dense matrix add" is an elementwise Add whose output has ``out_ndim``
    dimensions and non-broadcastable trailing matrix axes.
    """
    compiled = function([x, y], z, mode="FAST_RUN")

    def _is_dense_matrix_add(apply_node):
        op = apply_node.op
        if not isinstance(op, Elemwise) or not isinstance(op.scalar_op, Add):
            return False
        out_type = apply_node.outputs[0].type
        return (
            out_type.ndim == out_ndim
            and out_type.broadcastable[-2:] == (False, False)
        )

    graph_nodes = compiled.maker.fgraph.apply_nodes
    assert not any(_is_dense_matrix_add(node) for node in graph_nodes)

    assert_allclose(compiled(x_test, y_test), expected)


@pytest.mark.parametrize(
    "case_name",
    [
        "scalar_eye_mul",
        "batched_scalar_eye_mul",
        "batched_vector_eye_mul",
        "rectangular_eye_mul",
    ],
)
def test_add_diag_rewrite_for_eye_mul_cases(case_name):
    """Check the add-of-diagonals rewrite on the supported eye * x patterns.

    The rectangular case guards against the alloc_diag reconstruction path
    silently changing a non-square (n, m) output into a square matrix: for
    eye(n, m) inputs only numerical correctness is asserted, since the
    rewrite is expected to skip that pattern rather than restructure it.
    """
    rng = np.random.default_rng(sum(map(ord, f"test_add_diag_rewrite_for_{case_name}")))

    if case_name == "scalar_eye_mul":
        x = pt.scalar("x")
        y = pt.scalar("y")
        z = pt.eye(5) * x + pt.eye(5) * y

        x_test = np.asarray(rng.normal(), dtype=config.floatX)
        y_test = np.asarray(rng.normal(), dtype=config.floatX)
        expected = np.eye(5) * (x_test + y_test)
        out_ndim = 2

    elif case_name == "batched_scalar_eye_mul":
        x = pt.vector("x")
        y = pt.vector("y")
        z = pt.eye(5) * x[:, None, None] + pt.eye(5) * y[:, None, None]

        x_test = rng.normal(size=(4,)).astype(config.floatX)
        y_test = rng.normal(size=(4,)).astype(config.floatX)
        expected = np.eye(5)[None, :, :] * (
            x_test[:, None, None] + y_test[:, None, None]
        )
        out_ndim = 3

    elif case_name == "batched_vector_eye_mul":
        x = pt.matrix("x")
        y = pt.matrix("y")
        z = pt.eye(5) * x[:, None, :] + pt.eye(5) * y[:, None, :]

        x_test = rng.normal(size=(4, 5)).astype(config.floatX)
        y_test = rng.normal(size=(4, 5)).astype(config.floatX)
        expected = np.eye(5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :])
        out_ndim = 3

    elif case_name == "rectangular_eye_mul":
        # Non-square identity: the rewrite must preserve the (..., 3, 5)
        # output shape and values; no no-dense-add assertion applies here.
        x = pt.matrix("x")
        y = pt.matrix("y")
        z = pt.eye(3, 5) * x[:, None, :] + pt.eye(3, 5) * y[:, None, :]

        x_test = rng.normal(size=(4, 5)).astype(config.floatX)
        y_test = rng.normal(size=(4, 5)).astype(config.floatX)
        expected = np.eye(3, 5)[None, :, :] * (
            x_test[:, None, :] + y_test[:, None, :]
        )

        f_rewritten = function([x, y], z, mode="FAST_RUN")
        assert_allclose(f_rewritten(x_test, y_test), expected)
        return

    else:  # pragma: no cover
        raise ValueError(f"Unexpected case_name: {case_name}")

    _assert_rewrite_add_diag_no_dense_add(
        x=x,
        y=y,
        z=z,
        out_ndim=out_ndim,
        x_test=x_test,
        y_test=y_test,
        expected=expected,
    )
Loading