-
Notifications
You must be signed in to change notification settings - Fork 182
Linear algebra rewrites: diag sum rewrite #2022
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: v3
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,14 +18,15 @@ | |||||||||||||||||||||||||||||
| ExtractDiag, | ||||||||||||||||||||||||||||||
| Eye, | ||||||||||||||||||||||||||||||
| TensorVariable, | ||||||||||||||||||||||||||||||
| alloc_diag, | ||||||||||||||||||||||||||||||
| concatenate, | ||||||||||||||||||||||||||||||
| diag, | ||||||||||||||||||||||||||||||
| diagonal, | ||||||||||||||||||||||||||||||
| ones, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| from pytensor.tensor.blockwise import Blockwise | ||||||||||||||||||||||||||||||
| from pytensor.tensor.elemwise import DimShuffle, Elemwise | ||||||||||||||||||||||||||||||
| from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod | ||||||||||||||||||||||||||||||
| from pytensor.tensor.math import Dot, Prod, _matmul, add, log, outer, prod | ||||||||||||||||||||||||||||||
| from pytensor.tensor.nlinalg import ( | ||||||||||||||||||||||||||||||
| SVD, | ||||||||||||||||||||||||||||||
| KroneckerProduct, | ||||||||||||||||||||||||||||||
|
|
@@ -1145,3 +1146,78 @@ def scalar_solve_to_division(fgraph, node): | |||||||||||||||||||||||||||||
| copy_stack_trace(old_out, new_out) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return [new_out] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _extract_diagonal(x): | ||||||||||||||||||||||||||||||
| """Return the diagonal entries when ``x`` is provably diagonal; otherwise ``None``. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| The supported patterns are: | ||||||||||||||||||||||||||||||
| - ``AllocDiag`` with zero offset | ||||||||||||||||||||||||||||||
| - elementwise multiplication with an identity matrix | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| if not x.owner: | ||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if isinstance(x.owner.op, AllocDiag) and AllocDiag.is_offset_zero(x.owner): | ||||||||||||||||||||||||||||||
| return x.owner.inputs[0] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| inputs_or_none = _find_diag_from_eye_mul(x) | ||||||||||||||||||||||||||||||
| if inputs_or_none is None: | ||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| eye_input, non_eye_inputs = inputs_or_none | ||||||||||||||||||||||||||||||
| if len(non_eye_inputs) != 1: | ||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| [non_eye_input] = non_eye_inputs | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if non_eye_input.type.broadcastable[-2:] == (True, True): | ||||||||||||||||||||||||||||||
| scalar_input = non_eye_input.squeeze(axis=(-1, -2)) | ||||||||||||||||||||||||||||||
| if scalar_input.ndim == 0: | ||||||||||||||||||||||||||||||
| return scalar_input | ||||||||||||||||||||||||||||||
| # For batched scalar * eye, return batched diagonal entries (B, N), | ||||||||||||||||||||||||||||||
| # not batch scalars (B), so downstream alloc_diag reconstructs (B, N, N). | ||||||||||||||||||||||||||||||
| return scalar_input[..., None] * pt.ones( | ||||||||||||||||||||||||||||||
| (eye_input.shape[-1],), dtype=scalar_input.dtype | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| if non_eye_input.type.broadcastable[-2:] == (False, False): | ||||||||||||||||||||||||||||||
| return non_eye_input.diagonal(axis1=-1, axis2=-2) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| squeeze_axis = -2 if non_eye_input.type.broadcastable[-2] else -1 | ||||||||||||||||||||||||||||||
| return non_eye_input.squeeze(axis=squeeze_axis) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @register_canonicalize | ||||||||||||||||||||||||||||||
| @register_stabilize | ||||||||||||||||||||||||||||||
| @node_rewriter([add]) | ||||||||||||||||||||||||||||||
| def rewrite_add_diag_to_diag_add(fgraph, node): | ||||||||||||||||||||||||||||||
| """Rewrite sums of diagonal matrices into one diagonal construction. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Uses ``diag(A + B) = diag(A) + diag(B)`` in reverse to avoid full matrix adds. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| old_out = node.outputs[0] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if old_out.type.ndim < 2: | ||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| diagonal_inputs = [] | ||||||||||||||||||||||||||||||
| for inp in node.inputs: | ||||||||||||||||||||||||||||||
| diagonal_input = _extract_diagonal(inp) | ||||||||||||||||||||||||||||||
| if diagonal_input is None: | ||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||
| diagonal_inputs.append(diagonal_input) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| summed_diag = add(*diagonal_inputs) | ||||||||||||||||||||||||||||||
| if summed_diag.ndim == 0: | ||||||||||||||||||||||||||||||
| new_out = ( | ||||||||||||||||||||||||||||||
| pt.eye(old_out.shape[-2], old_out.shape[-1], dtype=old_out.dtype) | ||||||||||||||||||||||||||||||
| * summed_diag | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
| else: | |
| else: | |
| # alloc_diag always constructs a square matrix based on the diagonal length. | |
| # Only apply this rewrite when the original output is guaranteed square, | |
| # to avoid silently changing shapes for rectangular diagonal matrices. | |
| out_type_shape = getattr(old_out.type, "shape", None) | |
| if ( | |
| out_type_shape is None | |
| or out_type_shape[-2] is None | |
| or out_type_shape[-1] is None | |
| or out_type_shape[-2] != out_type_shape[-1] | |
| ): | |
| # We cannot prove the output is square; skip this optimization. | |
| return None |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,9 +12,10 @@ | |||||||||||||||||||||
| from pytensor.configdefaults import config | ||||||||||||||||||||||
| from pytensor.graph import FunctionGraph, ancestors | ||||||||||||||||||||||
| from pytensor.graph.rewriting.utils import rewrite_graph | ||||||||||||||||||||||
| from pytensor.scalar.basic import Add | ||||||||||||||||||||||
| from pytensor.tensor import swapaxes | ||||||||||||||||||||||
| from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape | ||||||||||||||||||||||
| from pytensor.tensor.elemwise import DimShuffle | ||||||||||||||||||||||
| from pytensor.tensor.elemwise import DimShuffle, Elemwise | ||||||||||||||||||||||
| from pytensor.tensor.math import dot, matmul | ||||||||||||||||||||||
| from pytensor.tensor.nlinalg import ( | ||||||||||||||||||||||
| SVD, | ||||||||||||||||||||||
|
|
@@ -1128,3 +1129,97 @@ def solve_op_in_graph(graph): | |||||||||||||||||||||
| np.testing.assert_allclose( | ||||||||||||||||||||||
| f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5 | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def test_add_diag_rewrite_from_diag_inputs(): | ||||||||||||||||||||||
| x = pt.vector("x") | ||||||||||||||||||||||
| y = pt.vector("y") | ||||||||||||||||||||||
| z = pt.diag(x) + pt.diag(y) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| rng = np.random.default_rng(sum(map(ord, "test_add_diag_rewrite_from_diag_inputs"))) | ||||||||||||||||||||||
| x_test = rng.normal(size=(7,)).astype(config.floatX) | ||||||||||||||||||||||
| y_test = rng.normal(size=(7,)).astype(config.floatX) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| _assert_rewrite_add_diag_no_dense_add( | ||||||||||||||||||||||
| x=x, | ||||||||||||||||||||||
| y=y, | ||||||||||||||||||||||
| z=z, | ||||||||||||||||||||||
| out_ndim=2, | ||||||||||||||||||||||
| x_test=x_test, | ||||||||||||||||||||||
| y_test=y_test, | ||||||||||||||||||||||
| expected=np.diag(x_test + y_test), | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _assert_rewrite_add_diag_no_dense_add(x, y, z, out_ndim, x_test, y_test, expected): | ||||||||||||||||||||||
| f_rewritten = function([x, y], z, mode="FAST_RUN") | ||||||||||||||||||||||
| nodes = f_rewritten.maker.fgraph.apply_nodes | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # The rewrite should avoid dense matrix additions. | ||||||||||||||||||||||
| has_dense_matrix_add = any( | ||||||||||||||||||||||
| isinstance(node.op, Elemwise) | ||||||||||||||||||||||
| and isinstance(node.op.scalar_op, Add) | ||||||||||||||||||||||
| and node.outputs[0].type.ndim == out_ndim | ||||||||||||||||||||||
| and node.outputs[0].type.broadcastable[-2:] == (False, False) | ||||||||||||||||||||||
| for node in nodes | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| assert not has_dense_matrix_add | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| assert_allclose(f_rewritten(x_test, y_test), expected) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @pytest.mark.parametrize( | ||||||||||||||||||||||
| "case_name", | ||||||||||||||||||||||
| [ | ||||||||||||||||||||||
| "scalar_eye_mul", | ||||||||||||||||||||||
| "batched_scalar_eye_mul", | ||||||||||||||||||||||
| "batched_vector_eye_mul", | ||||||||||||||||||||||
| ], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| def test_add_diag_rewrite_for_eye_mul_cases(case_name): | ||||||||||||||||||||||
| rng = np.random.default_rng(sum(map(ord, f"test_add_diag_rewrite_for_{case_name}"))) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if case_name == "scalar_eye_mul": | ||||||||||||||||||||||
| x = pt.scalar("x") | ||||||||||||||||||||||
| y = pt.scalar("y") | ||||||||||||||||||||||
| z = pt.eye(5) * x + pt.eye(5) * y | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| x_test = np.asarray(rng.normal(), dtype=config.floatX) | ||||||||||||||||||||||
| y_test = np.asarray(rng.normal(), dtype=config.floatX) | ||||||||||||||||||||||
| expected = np.eye(5) * (x_test + y_test) | ||||||||||||||||||||||
| out_ndim = 2 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| elif case_name == "batched_scalar_eye_mul": | ||||||||||||||||||||||
| x = pt.vector("x") | ||||||||||||||||||||||
| y = pt.vector("y") | ||||||||||||||||||||||
| z = pt.eye(5) * x[:, None, None] + pt.eye(5) * y[:, None, None] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| x_test = rng.normal(size=(4,)).astype(config.floatX) | ||||||||||||||||||||||
| y_test = rng.normal(size=(4,)).astype(config.floatX) | ||||||||||||||||||||||
| expected = np.eye(5)[None, :, :] * ( | ||||||||||||||||||||||
| x_test[:, None, None] + y_test[:, None, None] | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| out_ndim = 3 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| elif case_name == "batched_vector_eye_mul": | ||||||||||||||||||||||
| x = pt.matrix("x") | ||||||||||||||||||||||
| y = pt.matrix("y") | ||||||||||||||||||||||
| z = pt.eye(5) * x[:, None, :] + pt.eye(5) * y[:, None, :] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| x_test = rng.normal(size=(4, 5)).astype(config.floatX) | ||||||||||||||||||||||
| y_test = rng.normal(size=(4, 5)).astype(config.floatX) | ||||||||||||||||||||||
| expected = np.eye(5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :]) | ||||||||||||||||||||||
|
Comment on lines
+1207
to
+1211
|
||||||||||||||||||||||
| z = pt.eye(5) * x[:, None, :] + pt.eye(5) * y[:, None, :] | |
| x_test = rng.normal(size=(4, 5)).astype(config.floatX) | |
| y_test = rng.normal(size=(4, 5)).astype(config.floatX) | |
| expected = np.eye(5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :]) | |
| z = pt.eye(3, 5) * x[:, None, :] + pt.eye(3, 5) * y[:, None, :] | |
| x_test = rng.normal(size=(4, 5)).astype(config.floatX) | |
| y_test = rng.normal(size=(4, 5)).astype(config.floatX) | |
| expected = np.eye(3, 5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For
eye * xpatterns,_extract_diagonalbuilds a length-eye_input.shape[-1]diagonal vector for the batched-scalar case. Ifeye_inputis rectangular (n != m), the true diagonal length ismin(n, m), and reconstructing viaalloc_diaglater will produce the wrong shape/values. Consider either restricting this rewrite to squareEyeinputs (or squareold_out) or changing the representation/reconstruction to preserve rectangular identity shapes.