From fb1a2535fdf8a475f30f5039c386c88f1ee8b673 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Mon, 23 Mar 2026 21:52:17 -0500 Subject: [PATCH] Add slogdet rewrite for block diagonal matrices Decompose slogdet(block_diag(A, B, ...)) into (prod(sign_i), sum(logabsdet_i)) where each (sign_i, logabsdet_i) = slogdet(A_i). This avoids constructing and factoring the full block diagonal matrix when the slogdet of each block can be computed independently. Co-Authored-By: Claude Opus 4.6 --- pytensor/tensor/rewriting/linalg.py | 43 +++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 2c17020cd9..d31a5799b8 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -895,6 +895,49 @@ def rewrite_det_blockdiag(fgraph, node): return [prod(det_sub_matrices)] +@register_canonicalize +@register_stabilize +@node_rewriter([SLogDet]) +def rewrite_slogdet_blockdiag(fgraph, node): + """ + This rewrite simplifies the slogdet of a blockdiagonal matrix by computing + slogdet of each sub-matrix independently. + + slogdet(block_diag(a,b,c,...)) = (prod(sign(a), sign(b), ...), sum(logabsdet(a), logabsdet(b), ...)) + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner block_diag operation + potential_block_diag = node.inputs[0].owner + if not ( + potential_block_diag + and isinstance(potential_block_diag.op, Blockwise) + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) + ): + return None + + # Find the composing sub_matrices + sub_matrices = potential_block_diag.inputs + signs = [] + logabsdets = [] + for sub_mat in sub_matrices: + sign_i, logabsdet_i = SLogDet()(sub_mat) + signs.append(sign_i) + logabsdets.append(logabsdet_i) + + return [prod(signs), pt.add(*logabsdets)] + + @register_canonicalize @register_stabilize @node_rewriter([ExtractDiag])