From 75cae024bc1b56c2acf85e5fd46ebe4eb38130a4 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 31 Mar 2026 23:24:26 +0200 Subject: [PATCH 1/7] Elemwise: fuse sibling subgraphs in FusionOptimizer Extend FusionOptimizer to merge independent subgraphs that share inputs but have no producer-consumer edge (siblings like f(x) and g(x)). The eager expansion only walks producer-consumer edges, missing these. Also extract InplaceGraphOptimizer.try_inplace_on_node helper and _insert_sorted_subgraph to deduplicate insertion-point logic. --- pytensor/tensor/rewriting/elemwise.py | 178 +++++++++++++++++++----- tests/tensor/rewriting/test_elemwise.py | 17 +-- 2 files changed, 153 insertions(+), 42 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index b404dba4ee..56cf7a3d8f 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,6 +2,7 @@ import itertools import operator import sys +from collections import defaultdict from collections.abc import Generator, Sequence from functools import cache, reduce from heapq import heapify, heappop, heappush @@ -555,6 +556,26 @@ def apply(self, fgraph): callbacks_before = fgraph.execute_callbacks_times.copy() callback_before = fgraph.execute_callbacks_time + def _insert_sorted_subgraph( + sorted_subgraphs, all_subgraphs_bitset, ancestors_bitset, entry + ): + """Insert a subgraph entry into sorted_subgraphs respecting topological order.""" + if not (ancestors_bitset & all_subgraphs_bitset): + # No dependency on previous subgraphs, append at the end + sorted_subgraphs.append(entry) + else: + # Find the right insertion point by iterating in topological + # order (reverse of stored order), cumulatively excluding each + # subgraph until the dependency check passes. 
+ remaining_bitset = all_subgraphs_bitset + for index, (other_bitset, _) in enumerate(reversed(sorted_subgraphs)): + remaining_bitset &= ~other_bitset + if not (ancestors_bitset & remaining_bitset): + break + else: + raise RuntimeError("Failed to find insertion point for subgraph") + sorted_subgraphs.insert(-(index + 1), entry) + def find_fuseable_subgraphs( fg: FunctionGraph, ) -> Generator[tuple[tuple[Variable], tuple[Variable]], None, None]: @@ -832,43 +853,136 @@ def elemwise_scalar_op_has_c_code( if node_ancestors_bitset & subgraph_bitset ) - # Add new subgraph to sorted_subgraphs - # Because we start from sink nodes in reverse topological order, most times new subgraphs - # don't depend on previous subgraphs, so we can just append them at the end. - if not (unfuseable_ancestors_bitset & all_subgraphs_bitset): - # That's the case here - # None of the unfuseable_ancestors (i.e, the ancestors) are present in the previous collected subgraphs - sorted_subgraphs.append( - (subgraph_bitset, (subgraph_inputs, subgraph_outputs)) + _insert_sorted_subgraph( + sorted_subgraphs, + all_subgraphs_bitset, + unfuseable_ancestors_bitset, + (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), + ) + + all_subgraphs_bitset |= subgraph_bitset + + # Merge sibling groups: independent subgraphs or remaining + # candidate nodes that share inputs but have no producer-consumer + # edge between them. The eager expansion above only walks + # producer-consumer edges, so it misses siblings like f(x) and + # g(x) that share an input without one feeding into the other. + sibling_candidates: list = list(sorted_subgraphs) + for node in candidate_starting_nodes: + bf = nodes_bitflags.get(node) + if (bf is not None) and not (bf & all_subgraphs_bitset): + sibling_candidates.append( + (bf, (tuple(dict.fromkeys(node.inputs)), node.outputs)) ) - else: - # But not here, so we need to find the right position for insertion. 
- # We iterate through the previous subgraphs in topological order (reverse of the stored order). - # We cumulatively exclude each subgraph_bitset and perform the same dependency check again, until it passes. - remaining_subgraphs_bitset = all_subgraphs_bitset - for index, (other_subgraph_bitset, _) in enumerate( - reversed(sorted_subgraphs) - ): - # Exclude subgraph bitset - remaining_subgraphs_bitset &= ~other_subgraph_bitset - if not ( - unfuseable_ancestors_bitset & remaining_subgraphs_bitset + # Create a mapping from inputs to sibling groups that consume them. + # Skip scalar constants as they get inlined later and don't + # represent meaningful shared computation between siblings. + input_to_sibling_idxs: dict[Variable, list[int]] = defaultdict(list) + for sibling_idx, (_, (inputs, _)) in enumerate(sibling_candidates): + for inp in inputs: + if isinstance(inp, TensorConstant) and inp.unique_value is not None: + continue + input_to_sibling_idxs[inp].append(sibling_idx) + + for sibling_idxs in input_to_sibling_idxs.values(): + if len(sibling_idxs) < 2: + continue + + for i in range(len(sibling_idxs) - 1): + sibling_i = sibling_idxs[i] + if sibling_candidates[sibling_i] is None: + # Already merged in a previous iteration + continue + + bitset_i, (inputs_i, outputs_i) = sibling_candidates[sibling_i] + bcast_i = outputs_i[0].type.broadcastable + merged = False + for sibling_j in sibling_idxs[i + 1 :]: + if sibling_candidates[sibling_j] is None: + # Already merged in a previous iteration + continue + bitset_j, (inputs_j, outputs_j) = sibling_candidates[sibling_j] + if bcast_i != outputs_j[0].type.broadcastable: + continue + # Independence: neither is ancestor of the other + if ( + bitset_i & ancestors_bitsets[outputs_j[0].owner] + or bitset_j & ancestors_bitsets[outputs_i[0].owner] ): - break # bingo - else: # no-break - raise RuntimeError( - "Failed to find insertion point for fused subgraph" + continue + + # Merge sibling_j into sibling_i + merged_bitset = 
bitset_i | bitset_j + merged_outputs = (*outputs_i, *outputs_j) + merged_inputs = list(inputs_i) + merged_inputs_set = set(inputs_i) + for input_j in inputs_j: + if input_j not in merged_inputs_set: + merged_inputs.append(input_j) + merged_inputs_set.add(input_j) + # Hide sibling_j from future merges + sibling_candidates[sibling_j] = None + + # Update ancestor bitsets so that any node + # depending on part of the merged group now + # depends on all of it (mirrors the main loop). + merged_ancestors = reduce( + or_, + (ancestors_bitsets[o.owner] for o in merged_outputs), + ) + ancestors_bitsets |= ( + (n, n_anc | merged_ancestors) + for n, n_anc in ancestors_bitsets.items() + if n_anc & merged_bitset ) - sorted_subgraphs.insert( - -(index + 1), - (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), - ) - # Add subgraph to all_subgraphs_bitset - all_subgraphs_bitset |= subgraph_bitset + # Update locals for next iteration + bitset_i = merged_bitset + inputs_i = tuple(merged_inputs) + outputs_i = merged_outputs + merged = True + + if merged: + # Write back so other input lists see the merged state + sibling_candidates[sibling_i] = ( + bitset_i, + (inputs_i, outputs_i), + ) + + # Update sorted_subgraphs + merged_entry = (bitset_i, (inputs_i, outputs_i)) + + # Replace earliest existing sibling with merged graph; drop the rest + pos = next( + ( + sg_idx + for sg_idx, (sg_bitset, _) in enumerate( + sorted_subgraphs + ) + if sg_bitset & bitset_i + ), + None, + ) + if pos is not None: + sorted_subgraphs[:] = [ + *sorted_subgraphs[:pos], + merged_entry, + *( + e + for e in sorted_subgraphs[pos + 1 :] + if not (e[0] & bitset_i) + ), + ] + else: + # All siblings were singletons: need a sorted insertion + _insert_sorted_subgraph( + sorted_subgraphs, + all_subgraphs_bitset, + ancestors_bitsets[outputs_i[0].owner], + merged_entry, + ) + all_subgraphs_bitset |= bitset_i - # Finished exploring the whole graph - # Yield from sorted_subgraphs, discarding the subgraph_bitset yield 
from (io for _, io in sorted_subgraphs) max_operands = elemwise_max_operands_fct(None) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 6dbaa75b85..4d77b17268 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -996,16 +996,13 @@ def test_expansion_order(self): (np.sum(fxv + 5) * np.exp(fxv) / (fxv + 5),), ("float32",), ), - pytest.param( - ( - (sin(exp(fx)), exp(sin(fx))), - (fx,), - (fxv,), - 1, - (np.sin(np.exp(fxv)), np.exp(np.sin(fxv))), - ("float32", "float32"), - ), - marks=pytest.mark.xfail, # Not implemented yet + ( + (sin(exp(fx)), exp(sin(fx))), + (fx,), + (fxv,), + 1, + (np.sin(np.exp(fxv)), np.exp(np.sin(fxv))), + ("float32", "float32"), ), ], ) From 0b71930a1edb6cb7303d9cd1a7c24641a844f322 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 30 Mar 2026 19:09:50 +0200 Subject: [PATCH 2/7] Numba: fuse AdvancedSubtensor1 indexed reads into Elemwise Fuse single-client AdvancedSubtensor1 nodes into Elemwise loops, replacing indirect array reads with a single iteration loop that uses index arrays for input access. 
Before (2 nodes): temp = x[idx] # AdvancedSubtensor1, shape (919,) result = temp + y # Elemwise After (1 fused loop, x is read directly via idx): for k in range(919): result[k] = x[idx[k]] + y[k] - Introduce IndexedElemwise Op (in rewriting/indexed_elemwise.py) - Add FuseIndexedElemwise rewrite with SequenceDB - Extend _vectorized intrinsic to support indirect reads and writes - Add op_debug_information for dprint(print_op_info=True) --- pytensor/link/numba/dispatch/blockwise.py | 7 +- pytensor/link/numba/dispatch/elemwise.py | 99 ++- pytensor/link/numba/dispatch/random.py | 7 +- .../link/numba/dispatch/vectorize_codegen.py | 616 +++++++++++++----- pytensor/tensor/elemwise.py | 18 +- pytensor/tensor/rewriting/elemwise.py | 150 +++-- pytensor/tensor/rewriting/indexed_elemwise.py | 236 +++++++ pytensor/tensor/rewriting/numba.py | 2 + tests/benchmarks/test_gather_fusion.py | 75 +++ tests/link/numba/test_indexed_elemwise.py | 125 ++++ 10 files changed, 1091 insertions(+), 244 deletions(-) create mode 100644 pytensor/tensor/rewriting/indexed_elemwise.py create mode 100644 tests/benchmarks/test_gather_fusion.py create mode 100644 tests/link/numba/test_indexed_elemwise.py diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index f2405e3542..af4aa4aee3 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -11,6 +11,9 @@ register_funcify_and_cache_key, ) from pytensor.link.numba.dispatch.vectorize_codegen import ( + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, + NO_SIZE, _jit_options, _vectorized, encode_literals, @@ -90,7 +93,9 @@ def impl(*inputs_and_core_shapes): (), # constant_inputs inputs, tuple_core_shapes, - None, # size + NO_SIZE, + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, ) return impl diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 2308ddffdc..393d4d6266 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ 
b/pytensor/link/numba/dispatch/elemwise.py @@ -20,6 +20,10 @@ ) from pytensor.link.numba.dispatch.string_codegen import create_tuple_string from pytensor.link.numba.dispatch.vectorize_codegen import ( + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, + NO_SIZE, + _jit_options, _vectorized, encode_literals, store_core_outputs, @@ -35,13 +39,13 @@ Mul, Sub, TrueDiv, - get_scalar_type, maximum, ) from pytensor.scalar.basic import add as add_as from pytensor.tensor.blas import BatchedDot from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum +from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -312,8 +316,7 @@ def axis_apply_fn(x): @register_funcify_and_cache_key(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): - scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs] - scalar_node = op.scalar_op.make_node(*scalar_inputs) + scalar_node = op.make_scalar_node(*node.inputs) scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key( op.scalar_op, node=scalar_node, @@ -368,8 +371,10 @@ def impl(*inputs): True, # allow_core_scalar (), # constant_inputs inputs, - core_output_shapes, # core_shapes - None, # size + core_output_shapes, + NO_SIZE, + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, ) return impl @@ -390,6 +395,90 @@ def impl(*inputs): return elemwise, elemwise_key +@register_funcify_and_cache_key(IndexedElemwise) +def numba_funcify_IndexedElemwise(op, node, **kwargs): + """Generate fused Elemwise Numba code with indexed reads. + + Reads indexed_inputs specs stored on the Op by the rewriting pass, + and generates a single vectorized loop with indirect indexing. + + Outer inputs are ordered as:: + + [elemwise_inputs..., idx_0, idx_1, ...] 
+ """ + [elemwise_node] = [n for n in op.fgraph.apply_nodes if isinstance(n.op, Elemwise)] + + indexed_inputs = op.indexed_inputs + indexed_outputs_enc = encode_literals(()) + + # --- Scalar function and core op -------------------------------------- + scalar_node = elemwise_node.op.make_scalar_node(*elemwise_node.inputs) + scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key( + elemwise_node.op.scalar_op, node=scalar_node, **kwargs + ) + + nin_elemwise = len(elemwise_node.inputs) + nout = len(elemwise_node.outputs) + core_op_fn = store_core_outputs(scalar_op_fn, nin=nin_elemwise, nout=nout) + + # --- Broadcast and type encodings ------------------------------------- + input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs) + output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs) + output_dtypes = tuple(out.type.dtype for out in node.outputs) + inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items()) + core_output_shapes = tuple(() for _ in range(nout)) + + input_bc_patterns_enc = encode_literals(input_bc_patterns) + output_bc_patterns_enc = encode_literals(output_bc_patterns) + output_dtypes_enc = encode_literals(output_dtypes) + inplace_pattern_enc = encode_literals(inplace_pattern) + indexed_inputs_enc = encode_literals(indexed_inputs) + + def indexed_elemwise_fn(*outer_inputs): + raise NotImplementedError( + "IndexedElemwise cannot be evaluated in Python (non-JIT) mode." 
+ ) + + @overload(indexed_elemwise_fn, jit_options=_jit_options) + def ov_indexed_elemwise_fn(*outer_inputs): + def impl(*outer_inputs): + return _vectorized( + core_op_fn, + input_bc_patterns_enc, + output_bc_patterns_enc, + output_dtypes_enc, + inplace_pattern_enc, + True, # allow_core_scalar + (), # constant_inputs + outer_inputs, + core_output_shapes, + NO_SIZE, + indexed_inputs_enc, + indexed_outputs_enc, + ) + + return impl + + cache_version = (0, 1) + if scalar_cache_key is None: + key = None + else: + key = str( + ( + type(op), + "IndexedElemwise", + cache_version, + inplace_pattern, + input_bc_patterns, + indexed_inputs, + scalar_cache_key, + ) + ) + key = sha256(key.encode()).hexdigest() + + return indexed_elemwise_fn, key + + @register_funcify_and_cache_key(CAReduce) def numba_funcify_CAReduce(op, node, **kwargs): axes = op.axis diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 0ac6a301cc..09e69603f6 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -21,6 +21,9 @@ ) from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy from pytensor.link.numba.dispatch.vectorize_codegen import ( + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, + NO_SIZE, _jit_options, _vectorized, encode_literals, @@ -474,9 +477,11 @@ def impl(core_shape, rng, size, *dist_params): (rng,), dist_params, (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),), - None + NO_SIZE if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len), + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, ) return rng, draws diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index f804c0c04c..3f74dad09e 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import operator import pickle from collections.abc 
import Callable, Sequence from textwrap import indent @@ -13,12 +14,29 @@ from numba.core import cgutils from numba.core.base import BaseContext from numba.core.types.misc import NoneType +from numba.extending import overload from numba.np import arrayobj from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic +# Numba is missing getitem(0d_array, Ellipsis), so o[...] += val fails. +# Register it so store_core_outputs can use o[...] += val naturally. +@overload(operator.getitem, inline="always") +def _getitem_0d_ellipsis(arr, idx): + if ( + isinstance(arr, types.Array) + and arr.ndim == 0 + and isinstance(idx, types.EllipsisType) + ): + + def impl(arr, idx): + return arr[()] + + return impl + + def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() @@ -30,10 +48,7 @@ def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): to0, to1, ..., ton = core_op_fn(i0, i1, ..., in) o0[...] = to0 - o1[...] = to1 ... - on[...] 
= ton - """ inputs = [f"i{i}" for i in range(nin)] outputs = [f"o{i}" for i in range(nout)] @@ -74,73 +89,27 @@ def store_core_outputs({inp_signature}, {out_signature}): } -@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) -def _vectorized( - typingctx, - core_func, +def _decode_literal(val, name): + if not isinstance(val, types.Literal): + raise TypingError(f"{name} must be literal.") + return pickle.loads(base64.decodebytes(val.literal_value.encode())) + + +def _compute_vectorized_types( + input_types, input_bc_patterns, output_bc_patterns, output_dtypes, inplace_pattern, allow_core_scalar, - constant_inputs_types, - input_types, output_core_shape_types, - size_type, ): - arg_types = [ - core_func, - input_bc_patterns, - output_bc_patterns, - output_dtypes, - inplace_pattern, - allow_core_scalar, - constant_inputs_types, - input_types, - output_core_shape_types, - size_type, - ] - - if not isinstance(input_bc_patterns, types.Literal): - raise TypingError("input_bc_patterns must be literal.") - input_bc_patterns = input_bc_patterns.literal_value - input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode())) - - if not isinstance(output_bc_patterns, types.Literal): - raise TypeError("output_bc_patterns must be literal.") - output_bc_patterns = output_bc_patterns.literal_value - output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode())) - - if not isinstance(output_dtypes, types.Literal): - raise TypeError("output_dtypes must be literal.") - output_dtypes = output_dtypes.literal_value - output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode())) - - if not isinstance(inplace_pattern, types.Literal): - raise TypeError("inplace_pattern must be literal.") - inplace_pattern = inplace_pattern.literal_value - inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) - - if not isinstance(allow_core_scalar, types.Literal): - raise TypeError("allow_core_scalar must be 
literal.") - allow_core_scalar = allow_core_scalar.literal_value - + """Compute core input/output types and return type for vectorized intrinsics.""" batch_ndim = len(input_bc_patterns[0]) - nin = len(constant_inputs_types) + len(input_types) - nout = len(output_bc_patterns) - - if nin == 0: - raise TypingError("Empty argument list to vectorized op.") - if nout == 0: - raise TypingError("Empty list of outputs for vectorized op.") - - if not all(isinstance(input, types.Array) for input in input_types): + if not all(isinstance(t, types.Array) for t in input_types): raise TypingError("Vectorized inputs must be arrays.") - - if not all( - len(pattern) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns - ): + if not all(len(p) == batch_ndim for p in input_bc_patterns + output_bc_patterns): raise TypingError( "Vectorized broadcastable patterns must have the same length." ) @@ -149,12 +118,15 @@ def _vectorized( for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True): core_ndim = input_type.ndim - len(bc_pattern) if allow_core_scalar and core_ndim == 0: - core_input_type = input_type.dtype + core_input_types.append(input_type.dtype) else: - core_input_type = types.Array( - dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout + # FIXME: inheriting layout from the full array is wrong for F-order + # inputs with batch dims — the core slice won't be F-contiguous. 
+ core_input_types.append( + types.Array( + dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout + ) ) - core_input_types.append(core_input_type) core_out_types = [ types.Array(numba.from_dtype(np.dtype(dtype)), len(output_core_shape), "C") @@ -172,115 +144,45 @@ def _vectorized( ) ] - for output_idx, input_idx in inplace_pattern: - output_type = input_types[input_idx] + for output_idx, inp_idx in inplace_pattern: + output_type = input_types[inp_idx] core_out_types[output_idx] = types.Array( dtype=output_type.dtype, ndim=output_type.ndim - batch_ndim, - layout=input_type.layout, + layout=output_type.layout, ) out_types[output_idx] = output_type ret_type = types.Tuple(out_types) - if len(output_dtypes) == 1: ret_type = ret_type.types[0] - sig = ret_type(*arg_types) - # So we can access the constant values in codegen... - input_bc_patterns_val = input_bc_patterns - output_bc_patterns_val = output_bc_patterns - output_dtypes_val = output_dtypes - inplace_pattern_val = inplace_pattern - input_types = input_types - size_is_none = isinstance(size_type, NoneType) - - def codegen( - ctx, - builder, - sig, - args, - ): - [_, _, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args + return core_input_types, core_out_types, out_types, ret_type - constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) - inputs = cgutils.unpack_tuple(builder, inputs) - output_core_shapes = [ - cgutils.unpack_tuple(builder, shape) - for shape in cgutils.unpack_tuple(builder, output_core_shapes) - ] - size = None if size_is_none else cgutils.unpack_tuple(builder, size) - - inputs = [ - arrayobj.make_array(ty)(ctx, builder, val) - for ty, val in zip(input_types, inputs, strict=True) - ] - in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] - - iter_shape = compute_itershape( - ctx, - builder, - in_shapes, - input_bc_patterns_val, - size, - ) - outputs, output_types = make_outputs( - ctx, - builder, - iter_shape, - output_bc_patterns_val, - 
output_dtypes_val, - inplace_pattern_val, - inputs, - input_types, - output_core_shapes, - ) +def _codegen_return_outputs(ctx, builder, sig, outputs, inplace_pattern): + """Generate LLVM IR to return output arrays, handling incref for inplace.""" + incref_set = set(dict(inplace_pattern).keys()) - core_signature = typingctx.resolve_function_type( - core_func, - [ - *constant_inputs_types, - *core_input_types, - *core_out_types, - ], - {}, - ) + if len(outputs) == 1: + if incref_set: + ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue()) + return outputs[0]._getvalue() - make_loop_call( - typingctx, - ctx, + for idx in sorted(incref_set): + ctx.nrt.incref( builder, - core_func, - core_signature, - iter_shape, - constant_inputs, - inputs, - outputs, - input_bc_patterns_val, - output_bc_patterns_val, - input_types, - output_types, - core_scalar=allow_core_scalar, + sig.return_type.types[idx], + outputs[idx]._getvalue(), ) + return ctx.make_tuple( + builder, sig.return_type, [out._getvalue() for out in outputs] + ) - if len(outputs) == 1: - if inplace_pattern: - assert inplace_pattern[0][0] == 0 - ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue()) - return outputs[0]._getvalue() - - for inplace_idx in dict(inplace_pattern): - ctx.nrt.incref( - builder, - sig.return_type.types[inplace_idx], - outputs[inplace_idx]._getvalue(), - ) - return ctx.make_tuple( - builder, sig.return_type, [out._getvalue() for out in outputs] - ) - return sig, codegen +NO_INDEXED_INPUTS = encode_literals(()) +NO_INDEXED_OUTPUTS = encode_literals(()) +NO_SIZE = None def compute_itershape( @@ -386,6 +288,7 @@ def make_outputs( input_types: tuple[Any, ...], output_core_shapes: tuple, ) -> tuple[list[ir.Value], list[types.Array]]: + """Allocate output arrays for vectorized loop.""" output_arrays = [] output_arry_types = [] one = ir.IntType(64)(1) @@ -412,8 +315,8 @@ def make_outputs( array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) output_arrays.append(array) - 
# If there is no inplace operation, we know that all output arrays - # don't alias. Informing llvm can make it easier to vectorize. + # If there is no inplace operation, we know that all output + # arrays don't alias. Informing llvm can make it easier to vectorize. if not inplace: # The first argument is the output pointer arg = builder.function.args[0] @@ -436,23 +339,30 @@ def make_loop_call( input_types: tuple[Any, ...], output_types: tuple[Any, ...], core_scalar: bool = True, + input_read_spec: tuple[tuple[int, int] | None, ...] | None = None, + idx_arrs: list | None = None, + idx_types: tuple | None = None, + idx_load_axis: tuple[int, ...] | None = None, + idx_bc: tuple[tuple[bool, ...], ...] | None = None, ): safe = (False, False) n_outputs = len(outputs) - # TODO I think this is better than the noalias attribute - # for the input, but self_ref isn't supported in a released - # llvmlite version yet - # mod = builder.module - # domain = mod.add_metadata([], self_ref=True) - # input_scope = mod.add_metadata([domain], self_ref=True) - # output_scope = mod.add_metadata([domain], self_ref=True) - # input_scope_set = mod.add_metadata([input_scope, output_scope]) - # output_scope_set = mod.add_metadata([input_scope, output_scope]) - zero = ir.Constant(ir.IntType(64), 0) + def _wrap_negative_index(idx_val, dim_size, signed): + """Wrap a negative index by adding the dimension size: idx + size if idx < 0. + + Only emits the branch for signed index dtypes; unsigned indices are + returned as-is since they cannot be negative. + """ + if not signed: + return idx_val + is_neg = builder.icmp_signed("<", idx_val, zero) + wrapped = builder.add(idx_val, dim_size) + return builder.select(is_neg, wrapped, idx_val) + # Setup loops and initialize accumulators for outputs # This part corresponds to opening the loops loop_stack = [] @@ -475,14 +385,75 @@ def make_loop_call( # Code in the inner most loop... 
idxs = [loopval.index for loopval in loops] + # Load indirect indices for all index arrays. + # Each index array is 1D and is accessed by the loop counter for its axis. + indirect_idxs = [] + if idx_arrs is not None and idx_load_axis is not None: + for gi_k, (gi_arr, gi_type, ax) in enumerate( + zip(idx_arrs, idx_types, idx_load_axis) + ): + # Use zero if the index is statically broadcastable, loop counter otherwise + idx_is_bc = idx_bc[gi_k][0] if idx_bc and idx_bc[gi_k] else False + load_idx = zero if idx_is_bc else idxs[ax] + gi_ptr = cgutils.get_item_pointer2( + context, + builder, + gi_arr.data, + cgutils.unpack_tuple(builder, gi_arr.shape), + cgutils.unpack_tuple(builder, gi_arr.strides), + gi_type.layout, + [load_idx], + False, + False, + ) + val = builder.load(gi_ptr) + # Extend to i64 to match stride types in get_item_pointer2. + i64 = ir.IntType(64) + if val.type != i64: + if gi_type.dtype.signed: + val = builder.sext(val, i64) + else: + val = builder.zext(val, i64) + # Negative indices are wrapped at point of use (see + # _wrap_negative_index) since the dimension size depends on the + # source/target array being indexed. 
+ indirect_idxs.append(val) + # Load values from input arrays input_vals = [] - for input, input_type, bc in zip(inputs, input_types, input_bc, strict=True): + for input_i, (input, input_type, bc) in enumerate( + zip(inputs, input_types, input_bc, strict=True) + ): + spec = input_read_spec[input_i] if input_read_spec is not None else None core_ndim = input_type.ndim - len(bc) - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [ - zero - ] * core_ndim + if spec is not None: + assert idx_types is not None + # Single-index on one axis: replace that axis with the indirect index + indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} + input_shape = cgutils.unpack_tuple(builder, input.shape) + idxs_bc = [] + result_dim = 0 + for src_dim in range(input_type.ndim): + if src_dim in indexed_axes: + idx_k = indexed_axes[src_dim] + idx_val = _wrap_negative_index( + indirect_idxs[idx_k], + input_shape[src_dim], + signed=idx_types[idx_k].dtype.signed, + ) + idxs_bc.append(idx_val) + result_dim += 1 + else: + if result_dim < len(bc): + idxs_bc.append(zero if bc[result_dim] else idxs[result_dim]) + else: + idxs_bc.append(zero) + result_dim += 1 + else: + idxs_bc = [ + zero if bc else idx for idx, bc in zip(idxs, bc, strict=True) + ] + [zero] * core_ndim ptr = cgutils.get_item_pointer2( context, builder, @@ -496,8 +467,6 @@ def make_loop_call( if core_scalar and core_ndim == 0: # Retrive scalar item at index val = builder.load(ptr) - # val.set_metadata("alias.scope", input_scope_set) - # val.set_metadata("noalias", output_scope_set) else: # Retrieve array item at index # This is a streamlined version of Numba's `GUArrayArg.load` @@ -529,7 +498,9 @@ def make_loop_call( # Create output slices to pass to inner func output_slices = [] - for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True): + for output_i, (output, output_type, bc) in enumerate( + zip(outputs, output_types, output_bc, strict=True) + ): core_ndim = 
output_type.ndim - len(bc) size_type = output.shape.type.element # pyright: ignore[reportAttributeAccessIssue] output_shape = cgutils.unpack_tuple(builder, output.shape) # pyright: ignore[reportAttributeAccessIssue] @@ -548,7 +519,6 @@ def make_loop_call( idxs_bc, *safe, ) - # Retrieve array item at index # This is a streamlined version of Numba's `GUArrayArg.load` core_arry_type = types.Array( @@ -583,3 +553,299 @@ def make_loop_call( loop.__exit__(None, None, None) return + + +@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) +def _vectorized( + typingctx, + core_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + allow_core_scalar, + constant_inputs_types, + outer_input_types, + output_core_shape_types, + size_type, + indexed_inputs, + indexed_outputs, +): + """Like _vectorized but with indirect indexing for reads. + + Outer inputs are ordered as + ``[elemwise_inputs..., idx_0, idx_1, ...]``. + + ``indexed_inputs`` groups elemwise input positions by which index they + read through: e.g. ``((0, 2), (1,))`` means idx_0 reads inputs 0 and 2, + idx_1 reads input 1. Entries may be empty ``()`` for update-only indices. + + ``indexed_outputs`` is unused (always empty tuple for reads-only). + + Parameters + ---------- + outer_input_types : tuple of Array types + ``(elemwise_input_0, ..., elemwise_input_N, idx_0, ...)`` + indexed_inputs : literal str + Encoded ``tuple[tuple[int, ...], ...]``. + indexed_outputs : literal str + Encoded ``tuple`` (always empty). 
+ """ + arg_types = [ + core_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + allow_core_scalar, + constant_inputs_types, + outer_input_types, + output_core_shape_types, + size_type, + indexed_inputs, + indexed_outputs, + ] + + input_bc_patterns = _decode_literal(input_bc_patterns, "input_bc_patterns") + output_bc_patterns = _decode_literal(output_bc_patterns, "output_bc_patterns") + output_dtypes = _decode_literal(output_dtypes, "output_dtypes") + inplace_pattern = _decode_literal(inplace_pattern, "inplace_pattern") + indexed_inputs = _decode_literal(indexed_inputs, "indexed_inputs") + indexed_outputs = _decode_literal(indexed_outputs, "indexed_outputs") + + if not isinstance(allow_core_scalar, types.Literal): + raise TypingError("allow_core_scalar must be literal.") + allow_core_scalar = allow_core_scalar.literal_value + + n_indices = len(indexed_inputs) + n_elemwise = len(outer_input_types) - n_indices + source_input_types = tuple(outer_input_types[i] for i in range(n_elemwise)) + idx_types = tuple(outer_input_types[n_elemwise + k] for k in range(n_indices)) + + # indexed_inputs entries are (positions, axis) -- one per index array. + # For multi-index (e.g. x[idx_row, idx_col]), an input appears in multiple + # entries with different axes. We aggregate per-input into a tuple of + # (idx_k, src_axis) pairs. + # + # idx_load_axis[k] = which loop dim loads index array k. + # For multi-index on consecutive axes starting at A, all arrays in the + # group load from loop dim A (the first dim of the group in the output). 
+ _read_spec_dict: dict[int, list[tuple[int, int]]] = {} + idx_load_axis = [] + idx_bc_list = [] # per index array: broadcastable tuple + for k, entry in enumerate(indexed_inputs): + positions, axis = entry[0], entry[1] + idx_bc = entry[2] if len(entry) > 2 else (False,) + idx_bc_list.append(idx_bc) + for p in positions: + _read_spec_dict.setdefault(p, []).append((k, axis)) + # idx_load_axis[k] = which loop dim to use when loading index array k. + for k, entry in enumerate(indexed_inputs): + _positions, axis = entry[0], entry[1] + # Find if this index array shares an input with other indexed axes + # If so, the loop dim is the minimum axis in that group + min_axis = axis + for p in _positions: + if p in _read_spec_dict: + for _other_k, other_axis in _read_spec_dict[p]: + min_axis = min(min_axis, other_axis) + idx_load_axis.append(min_axis) + input_read_spec = tuple( + tuple(_read_spec_dict[p]) if p in _read_spec_dict else None + for p in range(n_elemwise) + ) + idx_load_axis = tuple(idx_load_axis) + + # Build effective input types that match input_bc_patterns ndim. + # For multi-indexed inputs, the source has more dims than the bc pattern + # (multiple source axes collapse into fewer loop dims). 
+ input_types = [] + for p, src_type in enumerate(source_input_types): + spec = input_read_spec[p] + if spec is not None and len(spec) > 1: + # Multi-index: effective ndim = source ndim - n_indexed + 1 + effective_ndim = src_type.ndim - len(spec) + 1 + input_types.append( + types.Array(src_type.dtype, effective_ndim, src_type.layout) + ) + else: + input_types.append(src_type) + input_types = tuple(input_types) + + core_input_types, core_out_types, _out_types, ret_type = _compute_vectorized_types( + input_types, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + allow_core_scalar, + output_core_shape_types, + ) + + sig = ret_type(*arg_types) + + size_is_none = isinstance(size_type, NoneType) + + # Save values for codegen closure + input_bc_patterns_val = input_bc_patterns + output_bc_patterns_val = output_bc_patterns + output_dtypes_val = output_dtypes + inplace_pattern_val = inplace_pattern + input_read_spec_val = input_read_spec + idx_types_val = idx_types + idx_load_axis_val = idx_load_axis + idx_bc_list_val = idx_bc_list + + def codegen(ctx, builder, sig, args): + [ + _, + _, + _, + _, + _, + _, + constant_inputs, + outer_inputs, + output_core_shapes, + size, + _, + _, + ] = args + + constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) + all_outer = cgutils.unpack_tuple(builder, outer_inputs) + output_core_shapes = [ + cgutils.unpack_tuple(builder, shape) + for shape in cgutils.unpack_tuple(builder, output_core_shapes) + ] + size = None if size_is_none else cgutils.unpack_tuple(builder, size) + + # First n_elemwise outer inputs are elemwise inputs (source arrays) + inputs = [ + arrayobj.make_array(source_input_types[i])(ctx, builder, all_outer[i]) + for i in range(n_elemwise) + ] + in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] + + # Next n_indices inputs are index arrays + idx_arrs = [ + arrayobj.make_array(idx_types_val[k])( + ctx, builder, all_outer[n_elemwise + k] + ) + for k in 
range(n_indices) + ] + + # Build iter_shapes for compute_itershape. + # For indexed inputs, the source array may have more dims than the + # iteration shape (multi-index collapses multiple source axes into one + # loop dim). Replace the source shape with a constructed shape that + # matches the bc pattern: one entry per loop dim, with index lengths + # substituted for the indexed loop dim(s). + iter_shapes = list(in_shapes) + iter_bc = list(input_bc_patterns_val) + idx_shapes = [ + cgutils.unpack_tuple(builder, idx_arrs[k].shape) for k in range(n_indices) + ] + for p, spec in enumerate(input_read_spec_val): + if spec is None: + continue + indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} + n_indexed = len(indexed_axes) + if n_indexed <= 1: + # Single-index: substitute on the indexed axis + idx_k, axis = spec[0] + iter_shapes[p] = list(iter_shapes[p]) + iter_shapes[p][axis] = idx_shapes[idx_k][0] + else: + # Multi-index: build collapsed shape from non-indexed source + # dims + index length for the indexed loop dim. + source_shape = cgutils.unpack_tuple(builder, inputs[p].shape) + batch_ndim = len(input_bc_patterns_val[p]) + new_shape = [] + for loop_d in range(batch_ndim): + if loop_d == 0: + # First loop dim = index length + new_shape.append(idx_shapes[spec[0][0]][0]) + else: + # Remaining loop dims = non-indexed source axes + src_d = n_indexed - 1 + loop_d + new_shape.append(source_shape[src_d]) + iter_shapes[p] = new_shape + + # Each index array participates in iter_shape validation on its load + # axis, just like any direct input. Its bc comes from the PyTensor + # type (encoded in idx_bc_list), so shape-1 indices broadcast correctly. 
+ batch_ndim = len(input_bc_patterns_val[0]) if input_bc_patterns_val else 0 + one = ir.IntType(64)(1) + for k in range(n_indices): + ax = idx_load_axis_val[k] + idx_shape_entry = [one] * batch_ndim + idx_shape_entry[ax] = idx_shapes[k][0] + iter_shapes.append(idx_shape_entry) + # bc on load axis from the index's own broadcastable; True elsewhere + idx_bc_on_ax = idx_bc_list_val[k][0] if idx_bc_list_val[k] else False + iter_bc.append( + tuple(True if d != ax else idx_bc_on_ax for d in range(batch_ndim)) + ) + + iter_shape = compute_itershape( + ctx, + builder, + iter_shapes, + tuple(iter_bc), + size, + ) + + outputs, output_types = make_outputs( + ctx, + builder, + iter_shape, + output_bc_patterns_val, + output_dtypes_val, + inplace_pattern_val, + inputs, + source_input_types, + output_core_shapes, + ) + + core_signature = typingctx.resolve_function_type( + core_func, + [ + *constant_inputs_types, + *core_input_types, + *core_out_types, + ], + {}, + ) + + make_loop_call( + typingctx, + ctx, + builder, + core_func, + core_signature, + iter_shape, + constant_inputs, + inputs, + outputs, + input_bc_patterns_val, + output_bc_patterns_val, + source_input_types, + output_types, + core_scalar=allow_core_scalar, + input_read_spec=input_read_spec_val, + idx_arrs=idx_arrs, + idx_types=idx_types_val, + idx_load_axis=idx_load_axis_val, + idx_bc=idx_bc_list_val, + ) + + return _codegen_return_outputs( + ctx, + builder, + sig, + outputs, + inplace_pattern, + ) + + return sig, codegen diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index ddb7376fce..a17c5e70ce 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -381,14 +381,22 @@ def __setstate__(self, d): self.nfunc = None self.inplace_pattern = frozendict(self.inplace_pattern) + def make_scalar_node(self, *inputs): + """Create a scalar Apply node matching the dtypes of tensor inputs. 
+ + Used by get_output_info, grad, and backend dispatchers to obtain + the scalar-level graph corresponding to this Elemwise operation. + """ + return self.scalar_op.make_node( + *[get_scalar_type(dtype=i.type.dtype).make_variable() for i in inputs] + ) + def get_output_info(self, *inputs): """Return the outputs dtype and broadcastable pattern and the dimshuffled inputs. """ - shadow = self.scalar_op.make_node( - *[get_scalar_type(dtype=i.type.dtype).make_variable() for i in inputs] - ) + shadow = self.make_scalar_node(*inputs) target_length = max(input.type.ndim for input in inputs) @@ -545,9 +553,7 @@ def as_scalar(t): scalar_inputs = list(map(as_scalar, inputs)) scalar_ograds = list(map(as_scalar, ograds)) - scalar_outputs = self.scalar_op.make_node( - *[get_scalar_type(dtype=i.type.dtype).make_variable() for i in inputs] - ).outputs + scalar_outputs = self.make_scalar_node(*inputs).outputs scalar_igrads = self.scalar_op.L_op( scalar_inputs, scalar_outputs, scalar_ograds ) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 56cf7a3d8f..517819da98 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -75,6 +75,91 @@ def create_inplace_node( ) -> Apply: pass + def _get_protected_inputs(self, fgraph): + """Collect inputs protected from in-place destruction.""" + protected = set( + itertools.chain.from_iterable( + f.protected for f in fgraph._features if isinstance(f, Supervisor) + ) + ) + protected.update(fgraph.outputs) + return protected + + def try_inplace_on_node( + self, + fgraph, + node, + candidate_pairs=None, + reason="inplace_optimizer", + extra_protected_inputs=frozenset(), + ): + """Try to make a single node operate in-place. + + First attempts all candidate pairs at once, then falls back to + one-at-a-time. Returns the (possibly replaced) node. 
+ + Parameters + ---------- + candidate_pairs + Pre-filtered and sorted list of ``((out_idx, out), (in_idx, inp))`` + pairs. If None, computed from ``filter_candidate_pairs`` using + the standard protections plus ``extra_protected_inputs``. + extra_protected_inputs + Additional variables that must not be used as in-place targets, + only used when ``candidate_pairs`` is None. + """ + if candidate_pairs is None: + protected_inputs = self._get_protected_inputs(fgraph) + protected_inputs.update(extra_protected_inputs) + candidate_pairs = self.filter_candidate_pairs( + fgraph, node, protected_inputs + ) + if not candidate_pairs: + return node + + # Try in-placing all outputs at once + tried_inputs = set() + inplace_pattern = {} + for (o, _), (i, _) in candidate_pairs: + if o not in inplace_pattern and i not in tried_inputs: + inplace_pattern[o] = [i] + tried_inputs.add(i) + + inplace_node = self.create_inplace_node(node, inplace_pattern) + if inplace_node.op.destroy_map == inplace_pattern: + replacements = tuple(zip(node.outputs, inplace_node.outputs)) + try: + fgraph.replace_all_validate(replacements, reason=reason) + except InconsistencyError: + pass + else: + copy_stack_trace(node.outputs, inplace_node.outputs) + return inplace_node + + # Fall back to one output/input at a time + tried_inputs = set() + inplace_pattern = {} + original_node = node + for (o, _), (i, _) in candidate_pairs: + if o not in inplace_pattern and i not in tried_inputs: + inplace_pattern[o] = [i] + tried_inputs.add(i) + + inplace_node = self.create_inplace_node(node, inplace_pattern) + if inplace_node.op.destroy_map != inplace_pattern: + # This Op can't respect this partial inplace pattern, + # We assume it can't support any other cases + break + replacements = tuple(zip(node.outputs, inplace_node.outputs)) + try: + fgraph.replace_all_validate(replacements, reason=reason) + node = inplace_node + except InconsistencyError: + inplace_pattern.pop(o) + if node is not original_node: + 
copy_stack_trace(original_node.outputs, node.outputs) + return node + def apply(self, fgraph): r""" @@ -128,12 +213,7 @@ def apply(self, fgraph): } large_graph = len(fgraph.apply_nodes) > 500 - protected_inputs = set( - itertools.chain.from_iterable( - f.protected for f in fgraph._features if isinstance(f, Supervisor) - ) - ) - protected_inputs.update(fgraph.outputs) + protected_inputs = self._get_protected_inputs(fgraph) root_destroyer = fgraph.destroy_handler.root_destroyer self_op = self.op @@ -164,14 +244,15 @@ def apply(self, fgraph): if not candidate_pairs: continue + # If the fgraph has updates, we try to prioritize in-placing + # on the pairs that correspond to the update sorted_candidate_pairs = candidate_pairs if op_updates and (node_updates := set(node.outputs) & set_op_updates): - # If the fgraph has updates, we try to prioritize in-placing on the pairs that correspond to the update direct_update_pairs = [] indirect_update_pairs = [] other_update_pairs = [] for pair in candidate_pairs: - ((o, out), (i, inp)) = pair + ((_o, out), (_i, inp)) = pair if out in node_updates: direct_update_inp = op_updates[out] if direct_update_inp is inp: @@ -191,54 +272,11 @@ def apply(self, fgraph): direct_update_pairs + indirect_update_pairs + other_update_pairs ) - # Try in-placing all outputs at once - tried_inputs = set() - inplace_pattern = {} - for (o, _), (i, _) in sorted_candidate_pairs: - if o not in inplace_pattern and i not in tried_inputs: - inplace_pattern[o] = [i] - tried_inputs.add(i) - - inplace_node = self.create_inplace_node(node, inplace_pattern) - if inplace_node.op.destroy_map == inplace_pattern: - replacements = tuple(zip(node.outputs, inplace_node.outputs)) - try: - fgraph.replace_all_validate(replacements, reason=reason) - except InconsistencyError: - prof["nb_eager_inconsistent"] += 1 - else: - prof["nb_replaced"] += 1 - copy_stack_trace(node.outputs, inplace_node.outputs) - continue - - # If it fails or doesn't match the desired inplace pattern, 
try one output/input at a time - tried_inputs = set() - inplace_pattern = {} - replaced = False - original_node = node - for (o, _), (i, _) in sorted_candidate_pairs: - if o not in inplace_pattern and i not in tried_inputs: - inplace_pattern[o] = [i] - tried_inputs.add(i) - - inplace_node = self.create_inplace_node(node, inplace_pattern) - if inplace_node.op.destroy_map != inplace_pattern: - # This Op can't respect this partial inplace pattern, - # We assume it can't support any other cases - break - else: - replacements = tuple(zip(node.outputs, inplace_node.outputs)) - try: - fgraph.replace_all_validate(replacements, reason=reason) - node = inplace_node - replaced = True - except InconsistencyError: - prof["nb_inconsistent"] += 1 - # The input, not the output caused inconsistencies - inplace_pattern.pop(o) - if replaced: - copy_stack_trace(original_node.outputs, node.outputs) - prof["nb_replaced"] += replaced + result = self.try_inplace_on_node( + fgraph, node, sorted_candidate_pairs, reason=reason + ) + if result is not node: + prof["nb_replaced"] += 1 return prof diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py new file mode 100644 index 0000000000..172e81ad5a --- /dev/null +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -0,0 +1,236 @@ +"""Fuse indexed reads into Elemwise iteration loops. + +Introduces ``IndexedElemwise``, an ``OpFromGraph`` that wraps +``AdvancedSubtensor1`` + ``Elemwise`` subgraphs so the Numba backend can +generate a single loop with indirect indexing, eliminating materialised +intermediate arrays. 
+""" + +from pytensor.compile import optdb +from pytensor.compile.builders import OpFromGraph +from pytensor.graph.rewriting.basic import GraphRewriter +from pytensor.graph.rewriting.db import SequenceDB +from pytensor.printing import op_debug_information +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer +from pytensor.tensor.subtensor import ( + AdvancedSubtensor1, +) + + +indexed_elemwise_optdb = SequenceDB() +optdb.register( + "fuse_indexed_into_elemwise", + indexed_elemwise_optdb, + "numba", + # After inplace_elemwise (position=50.5) so we see final inplace patterns, + # same position as other numba-specific rewrites (BlockwiseWithCoreShape). + position=100, +) + + +class IndexedElemwise(OpFromGraph): + """Fuse indexed reads into a single Elemwise iteration loop. + + Absorbs ``AdvancedSubtensor1`` (indexed reads on inputs) into one loop, + avoiding materialisation of intermediate arrays. + + Inner fgraph contains the unfused subgraph. + Non-Numba backends run it as-is via ``OpFromGraph.perform``. + The Numba backend generates a single loop with indirect indexing. + + Outer inputs are ordered as:: + + [elemwise_inputs..., idx_0, idx_1, ...] + + Elemwise inputs whose values are read via an index have their source + arrays substituted in place. + + Parameters + ---------- + indexed_inputs : tuple of (tuple[int, ...], int, tuple[bool, ...]) + One entry per index array: ``(elemwise_input_positions, axis, idx_broadcastable)``. 
+ """ + + def __init__(self, *args, indexed_inputs=(), **kwargs): + self.indexed_inputs = indexed_inputs + super().__init__(*args, on_unused_input="ignore", **kwargs) + + def __str__(self): + for node in self.fgraph.apply_nodes: + if isinstance(node.op, Elemwise): + return f"IndexedElemwise{{{node.op!s}}}" + return "IndexedElemwise" + + +@op_debug_information.register(IndexedElemwise) +def _op_debug_information_IndexedElemwise(op, node): + info = {} + + n_idx = len(op.indexed_inputs) + n_elemwise = len(node.inputs) - n_idx + + # Annotate indexed-read inputs + for k, (positions, _axis, _bc) in enumerate(op.indexed_inputs): + idx_label = f"idx_{k}" + for pos in positions: + if pos < len(node.inputs): + info[node.inputs[pos]] = f"indexed read ({idx_label})" + + # Annotate index arrays (after elemwise inputs) + for k in range(n_idx): + idx_pos = n_elemwise + k + if idx_pos < len(node.inputs): + info[node.inputs[idx_pos]] = f"idx_{k}" + + return {node: info} + + +class FuseIndexedElemwise(GraphRewriter): + """Fuse indexed reads into Elemwise loops. + + Absorbs single-client ``AdvancedSubtensor1`` on inputs (indexed reads) + into the Elemwise iteration, avoiding intermediate arrays. + + Supports multiple index arrays: e.g. ``x[idx_a] + y[idx_b]`` produces + two index groups. Index arrays are shared between reads when they refer + to the same variable. + """ + + def apply(self, fgraph): + def _get_indexed_read_info(var): + """Extract indexed-read info from a variable. + + Returns ``(source, [(idx_var, axis), ...])`` or ``None``. + Handles: + - ``AdvancedSubtensor1(source, idx)`` -> single index on axis 0 + """ + if var.owner is None: + return None + op = var.owner.op + if isinstance(op, AdvancedSubtensor1): + return (var.owner.inputs[0], [(var.owner.inputs[1], 0)]) + return None + + def find_indexed_input_groups(fgraph, node): + """Find single-client indexed-read inputs grouped by (index, axis). 
+ + Returns ``[(idx_var, axis, (pos, ...))]`` -- one entry per distinct + ``(idx_var, axis)`` pair. + """ + groups = {} # (idx_var, axis) -> (idx_var, axis, list of positions) + for i, inp in enumerate(node.inputs): + info = _get_indexed_read_info(inp) + if info is None: + continue + if any(c is not node for c, _ in fgraph.clients[inp]): + continue + _source, idx_axis_pairs = info + for idx_var, axis in idx_axis_pairs: + key = (idx_var, axis) + if key not in groups: + groups[key] = (idx_var, axis, []) + groups[key][2].append(i) + + return [(var, axis, tuple(pos)) for var, axis, pos in groups.values()] + + for node in reversed(fgraph.toposort()): + if not isinstance(node.op, Elemwise): + continue + + read_groups = find_indexed_input_groups(fgraph, node) + + if not read_groups: + continue + + indexed_positions = { + p for _, _ax, positions in read_groups for p in positions + } + + # If any inplace targets an indexed-read input, strip and re-run + # inplace with those inputs protected + if any( + inp_idx in indexed_positions + for inp_idx in node.op.inplace_pattern.values() + ): + stripped_node = Elemwise(node.op.scalar_op).make_node(*node.inputs) + fgraph.replace_all_validate( + list(zip(node.outputs, stripped_node.outputs)), + reason="fuse_indexed_elemwise_strip_inplace", + ) + + protected = frozenset( + stripped_node.inputs[i] for i in indexed_positions + ) + node = InplaceElemwiseOptimizer().try_inplace_on_node( + fgraph, + stripped_node, + reason="fuse_indexed_elemwise_reinplace", + extra_protected_inputs=protected, + ) + + # Merge read index arrays into a unified ordered list + all_idx_groups = {} # (idx_var, axis) -> (idx_var, position) + for idx_var, _ax, _ in read_groups: + key = (idx_var, _ax) + if key not in all_idx_groups: + all_idx_groups[key] = (idx_var, len(all_idx_groups)) + + n_indices = len(all_idx_groups) + idx_vars = [None] * n_indices + for idx_var, pos in all_idx_groups.values(): + idx_vars[pos] = idx_var + + # Build destroy_map + 
outer_destroy_map = {} + for out_idx, inp_idx in node.op.inplace_pattern.items(): + if inp_idx not in indexed_positions: + outer_destroy_map[out_idx] = [inp_idx] + + # Inner fgraph inputs: + # [elemwise_inputs (sources substituted)..., idx_0, ...] + inner_inputs = [ + inp.owner.inputs[0] if i in indexed_positions else inp + for i, inp in enumerate(node.inputs) + ] + inner_inputs = inner_inputs + idx_vars + + outer_inputs = list(inner_inputs) + + # Build indexed_inputs spec for the Op + indexed_inputs_spec = [None] * n_indices + for idx_var, _ax, positions in read_groups: + key = (idx_var, _ax) + _var, idx_pos = all_idx_groups[key] + indexed_inputs_spec[idx_pos] = ( + positions, + _ax, + idx_var.type.broadcastable, + ) + # Fill entries for index arrays with no reads + for k in range(n_indices): + if indexed_inputs_spec[k] is None: + for (iv, ax), (v, p) in all_idx_groups.items(): + if p == k: + indexed_inputs_spec[k] = ((), ax, iv.type.broadcastable) + break + + new_outs = IndexedElemwise( + inner_inputs, + node.outputs, + destroy_map=outer_destroy_map, + indexed_inputs=tuple(indexed_inputs_spec), + )(*outer_inputs, return_list=True) + + fgraph.replace_all_validate( + list(zip(node.outputs, new_outs)), + reason="fuse_indexed_into_elemwise", + ) + + +indexed_elemwise_optdb.register( + "fuse_indexed_elemwise", + FuseIndexedElemwise(), + "numba", + position=1, +) diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index 6bb9ed5bd9..88714e0b27 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -1,3 +1,5 @@ +# Import to trigger registration of IndexedElemwise rewrites +import pytensor.tensor.rewriting.indexed_elemwise # noqa: F401 from pytensor.compile import optdb from pytensor.graph import node_rewriter from pytensor.graph.rewriting.basic import dfs_rewriter diff --git a/tests/benchmarks/test_gather_fusion.py b/tests/benchmarks/test_gather_fusion.py new file mode 100644 index 
0000000000..04525a292d --- /dev/null +++ b/tests/benchmarks/test_gather_fusion.py @@ -0,0 +1,75 @@ +"""Micro-benchmarks for Elemwise fusion with indexed reads. + +Tests the benefit of fusing AdvancedSubtensor1 (indexed reads) into Elemwise +loops, avoiding materialization of intermediate arrays. +""" + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor import config +from pytensor.compile.mode import get_mode +from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.subtensor import advanced_subtensor1 + + +@pytest.fixture( + params=[ + (85, 919, 2, 6), # radon-like: small + (1000, 100_000, 2, 4), # medium + ], + ids=["small-85x919", "medium-1Kx100K"], +) +def gather_benchmark_setup(request): + n_bins, n_data, n_gathered, n_direct = request.param + + rng = np.random.default_rng(42) + idx = rng.integers(n_bins, size=n_data).astype(np.int64) + idx.sort() + + sources = [pt.vector(f"src_{i}", shape=(n_bins,)) for i in range(n_gathered)] + directs = [pt.vector(f"dir_{i}", shape=(n_data,)) for i in range(n_direct)] + + terms = [advanced_subtensor1(s, idx) for s in sources] + directs + out = terms[0] + for t in terms[1:]: + out = out + t + + inputs = sources + directs + numba_mode = get_mode("NUMBA") + + fn_fused = pytensor.function(inputs, out, mode=numba_mode, trust_input=True) + fn_unfused = pytensor.function( + inputs, + out, + mode=numba_mode.excluding("fuse_indexed_into_elemwise"), + trust_input=True, + ) + + assert any( + isinstance(n.op, IndexedElemwise) for n in fn_fused.maker.fgraph.toposort() + ), "IndexedElemwise not found in fused graph" + assert not any( + isinstance(n.op, IndexedElemwise) for n in fn_unfused.maker.fgraph.toposort() + ), "IndexedElemwise found in unfused graph" + + rng = np.random.default_rng(1) + vals = [rng.normal(size=inp.type.shape).astype(config.floatX) for inp in inputs] + + np.testing.assert_allclose(fn_fused(*vals), fn_unfused(*vals), rtol=1e-10) + + 
return fn_fused, fn_unfused, vals + + +def test_gather_fusion_fused(gather_benchmark_setup, benchmark): + fn_fused, _, vals = gather_benchmark_setup + fn_fused(*vals) # warmup + benchmark(fn_fused, *vals) + + +def test_gather_fusion_unfused(gather_benchmark_setup, benchmark): + _, fn_unfused, vals = gather_benchmark_setup + fn_unfused(*vals) # warmup + benchmark(fn_unfused, *vals) diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py new file mode 100644 index 0000000000..f5ea8bea15 --- /dev/null +++ b/tests/link/numba/test_indexed_elemwise.py @@ -0,0 +1,125 @@ +"""Tests for IndexedElemwise fusion (indexed reads in Elemwise loops).""" + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor.compile.mode import get_mode +from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.subtensor import advanced_subtensor1 + + +numba = pytest.importorskip("numba") + +NUMBA_MODE = get_mode("NUMBA") +NUMBA_NO_FUSION = NUMBA_MODE.excluding("fuse_indexed_into_elemwise") + + +def fused_and_unfused(inputs, output): + """Compile fused and unfused versions of a graph.""" + fn = pytensor.function(inputs, output, mode=NUMBA_MODE, trust_input=True) + fn_u = pytensor.function(inputs, output, mode=NUMBA_NO_FUSION, trust_input=True) + return fn, fn_u + + +def assert_fused(fn): + """Assert that the compiled graph contains an IndexedElemwise node.""" + assert any(isinstance(n.op, IndexedElemwise) for n in fn.maker.fgraph.toposort()), ( + "IndexedElemwise not found in fused graph" + ) + + +class TestIndexedReadFusion: + """Test indexed reads (AdvancedSubtensor1) fused into Elemwise.""" + + def test_single_index_axis0(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(919,)) + fn, fn_u = fused_and_unfused([x, y], advanced_subtensor1(x, idx) + y) + assert_fused(fn) + 
xv, yv = rng.normal(size=(85,)), rng.normal(size=(919,)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + def test_multiple_gathered_sources(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x1 = pt.vector("x1", shape=(85,)) + x2 = pt.vector("x2", shape=(85,)) + y = pt.vector("y", shape=(919,)) + fn, fn_u = fused_and_unfused( + [x1, x2, y], advanced_subtensor1(x1, idx) + advanced_subtensor1(x2, idx) + y + ) + assert_fused(fn) + xv1, xv2, yv = ( + rng.normal(size=(85,)), + rng.normal(size=(85,)), + rng.normal(size=(919,)), + ) + np.testing.assert_allclose(fn(xv1, xv2, yv), fn_u(xv1, xv2, yv), rtol=1e-10) + + def test_broadcast_index_axis0(self): + """Static shape=(1,) index on axis 0 broadcasts against larger direct input.""" + x = pt.vector("x", shape=(100,)) + y = pt.vector("y", shape=(50,)) + idx = np.array([5], dtype=np.int64) # shape (1,), broadcastable + out = advanced_subtensor1(x, idx) + y + fn, fn_u = fused_and_unfused([x, y], out) + assert_fused(fn) + xv = np.arange(100, dtype="float64") + yv = np.ones(50) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + def test_negative_indices(self): + """Negative indices must be handled correctly (sign-extended, not zero-extended).""" + rng = np.random.default_rng(42) + # Use unknown static shape so negative indices can't be canonicalized away + x = pt.vector("x") + y = pt.vector("y") + idx = pt.vector("idx", dtype="int64") + fn, fn_u = fused_and_unfused([x, idx, y], x[idx] + y) + assert_fused(fn) + xv = rng.normal(size=100) + # Negative indices: -1 means last element, -2 second to last, etc. + idxv = np.array([-1, -2, -3, 0, 1], dtype=np.int64) + yv = rng.normal(size=5) + np.testing.assert_allclose(fn(xv, idxv, yv), fn_u(xv, idxv, yv), rtol=1e-10) + + +class TestShapeValidation: + """Test that mismatched index/input shapes raise runtime errors. + + All inputs use ``shape=(None,)`` so shapes are unknown at compile time. 
+ The fused loop's ``compute_itershape`` must catch mismatches at runtime. + """ + + def test_mismatched_index_and_direct_input(self): + """Index length doesn't match direct input on the same loop dim.""" + x = pt.vector("x", shape=(None,)) + y = pt.vector("y", shape=(None,)) + idx = pt.vector("idx", dtype="int64", shape=(None,)) + out = x[idx] + y + fn = pytensor.function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Matching: idx=50, y=50 — should work + fn(np.zeros(100), np.zeros(50, dtype=np.int64), np.zeros(50)) + # Mismatched: idx=50, y=49 — should error + with pytest.raises(Exception): + fn(np.zeros(100), np.zeros(50, dtype=np.int64), np.zeros(49)) + + def test_runtime_broadcast_on_index_dim(self): + """Symbolic shapes that happen to be 1 at runtime — broadcast check.""" + x = pt.vector("x", shape=(None,)) + y = pt.vector("y", shape=(None,)) + idx = pt.vector("idx", dtype="int64", shape=(None,)) + out = x[idx] + y + fn = pytensor.function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Both idx and y have length 1 — should work (both agree on dim 0) + result = fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(1)) + assert result.shape == (1,) + # idx=1, y=5 — should error (shape mismatch, no static broadcast info) + with pytest.raises(Exception): + fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(5)) From e192f35f55ae568d752017c2f8b0013870572e83 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 30 Mar 2026 20:09:17 +0200 Subject: [PATCH 3/7] Numba: fuse AdvancedIncSubtensor1 indexed updates into Elemwise Extend the IndexedElemwise fusion to also absorb AdvancedIncSubtensor1 (indexed set/inc) on the output side. 
Before (3 nodes): temp = Elemwise(x[idx], y) # shape (919,) result = IncSubtensor(target, temp, idx) # target shape (85,) After (1 fused loop, target is an input): for k in range(919): target[idx[k]] += scalar_fn(x[idx[k]], y[k]) - FuseIndexedElemwise now detects AdvancedIncSubtensor1 consumers - Reject fusion when val broadcasts against target's non-indexed axes - store_core_outputs supports inc mode via o[...] += val - Inner fgraph always uses inplace IncSubtensor - op_debug_information shows buf_N / idx_N linkage --- pytensor/link/numba/dispatch/elemwise.py | 42 +- .../link/numba/dispatch/vectorize_codegen.py | 280 ++++++--- pytensor/tensor/rewriting/indexed_elemwise.py | 222 ++++++- tests/benchmarks/test_gather_fusion.py | 85 ++- tests/link/numba/test_indexed_elemwise.py | 591 +++++++++++++++++- 5 files changed, 1101 insertions(+), 119 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 393d4d6266..17f26b4860 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -397,19 +397,35 @@ def impl(*inputs): @register_funcify_and_cache_key(IndexedElemwise) def numba_funcify_IndexedElemwise(op, node, **kwargs): - """Generate fused Elemwise Numba code with indexed reads. + """Generate fused Elemwise Numba code with indexed reads and updates. - Reads indexed_inputs specs stored on the Op by the rewriting pass, - and generates a single vectorized loop with indirect indexing. + Reads indexed_inputs/indexed_outputs specs stored on the Op by the + rewriting pass, and generates a single vectorized loop with indirect + indexing. Outer inputs are ordered as:: - [elemwise_inputs..., idx_0, idx_1, ...] + [elemwise_inputs..., idx_0, idx_1, ..., update_target_0, ...] 
""" [elemwise_node] = [n for n in op.fgraph.apply_nodes if isinstance(n.op, Elemwise)] indexed_inputs = op.indexed_inputs - indexed_outputs_enc = encode_literals(()) + indexed_outputs = op.indexed_outputs + + # Derive update_out_set and inc_outputs from stored specs + update_out_set = frozenset( + out_idx + for entry in indexed_outputs + if entry is not None + for out_idx in entry[0] + ) + inc_outputs = frozenset( + out_idx + for entry in indexed_outputs + if entry is not None + for out_idx in entry[0] + if entry[1] == "inc" + ) # --- Scalar function and core op -------------------------------------- scalar_node = elemwise_node.op.make_scalar_node(*elemwise_node.inputs) @@ -419,13 +435,21 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): nin_elemwise = len(elemwise_node.inputs) nout = len(elemwise_node.outputs) - core_op_fn = store_core_outputs(scalar_op_fn, nin=nin_elemwise, nout=nout) + + core_op_fn = store_core_outputs( + scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs + ) # --- Broadcast and type encodings ------------------------------------- input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs) output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs) output_dtypes = tuple(out.type.dtype for out in node.outputs) - inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items()) + # Filter out inplace entries for update outputs (handled by scatter) + inplace_pattern = tuple( + (out_idx, inp_idx) + for out_idx, inp_idx in elemwise_node.op.inplace_pattern.items() + if out_idx not in update_out_set + ) core_output_shapes = tuple(() for _ in range(nout)) input_bc_patterns_enc = encode_literals(input_bc_patterns) @@ -433,6 +457,7 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): output_dtypes_enc = encode_literals(output_dtypes) inplace_pattern_enc = encode_literals(inplace_pattern) indexed_inputs_enc = encode_literals(indexed_inputs) + indexed_outputs_enc = 
encode_literals(indexed_outputs) def indexed_elemwise_fn(*outer_inputs): raise NotImplementedError( @@ -459,7 +484,7 @@ def impl(*outer_inputs): return impl - cache_version = (0, 1) + cache_version = (0, 2) if scalar_cache_key is None: key = None else: @@ -471,6 +496,7 @@ def impl(*outer_inputs): inplace_pattern, input_bc_patterns, indexed_inputs, + indexed_outputs, scalar_cache_key, ) ) diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 3f74dad09e..a4d85d93b4 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -41,14 +41,23 @@ def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() -def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: +def store_core_outputs( + core_op_fn: Callable, nin: int, nout: int, inc_outputs: frozenset = frozenset() +) -> Callable: """Create a Numba function that wraps a core function and stores its vectorized outputs. @njit def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): to0, to1, ..., ton = core_op_fn(i0, i1, ..., in) - o0[...] = to0 + o0[...] = to0 # direct outputs + o1[...] += to1 # inc outputs (indexed update) ... + + Parameters + ---------- + inc_outputs : frozenset + Output indices that use ``+=`` (increment) instead of ``=`` (assign). + Used for indexed-update outputs in fused loops. """ inputs = [f"i{i}" for i in range(nin)] outputs = [f"o{i}" for i in range(nout)] @@ -58,8 +67,12 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): out_signature = ", ".join(outputs) inner_out_signature = ", ".join(inner_outputs) store_outputs = "\n".join( - f"{output}[...] = {inner_output}" - for output, inner_output in zip(outputs, inner_outputs, strict=True) + f"{output}[...] += {inner_output}" + if i in inc_outputs + else f"{output}[...] 
= {inner_output}" + for i, (output, inner_output) in enumerate( + zip(outputs, inner_outputs, strict=True) + ) ) func_src = f""" def store_core_outputs({inp_signature}, {out_signature}): @@ -160,9 +173,18 @@ def _compute_vectorized_types( return core_input_types, core_out_types, out_types, ret_type -def _codegen_return_outputs(ctx, builder, sig, outputs, inplace_pattern): - """Generate LLVM IR to return output arrays, handling incref for inplace.""" - incref_set = set(dict(inplace_pattern).keys()) +def _codegen_return_outputs( + ctx, builder, sig, outputs, inplace_pattern, extra_incref=frozenset() +): + """Generate LLVM IR to return output arrays, handling incref for inplace. + + Parameters + ---------- + extra_incref : frozenset + Additional output indices (e.g. indexed-update outputs) that alias an input + buffer and need an incref before returning, beyond inplace outputs. + """ + incref_set = set(dict(inplace_pattern).keys()) | set(extra_incref) if len(outputs) == 1: if incref_set: @@ -287,8 +309,16 @@ def make_outputs( inputs: tuple[Any, ...], input_types: tuple[Any, ...], output_core_shapes: tuple, + update_outputs: dict | None = None, ) -> tuple[list[ir.Value], list[types.Array]]: - """Allocate output arrays for vectorized loop.""" + """Allocate output arrays for vectorized loop. + + Parameters + ---------- + update_outputs : dict, optional + Mapping ``{output_idx: (array, array_type)}`` for outputs that reuse + a scatter-target input buffer instead of being freshly allocated. 
+ """ output_arrays = [] output_arry_types = [] one = ir.IntType(64)(1) @@ -296,6 +326,10 @@ def make_outputs( for i, (core_shape, bc, dtype) in enumerate( zip(output_core_shapes, out_bc, dtypes, strict=True) ): + if update_outputs is not None and i in update_outputs: + output_arrays.append(update_outputs[i][0]) + output_arry_types.append(update_outputs[i][1]) + continue if i in inplace_dict: output_arrays.append(inputs[inplace_dict[i]]) output_arry_types.append(input_types[inplace_dict[i]]) @@ -315,9 +349,9 @@ def make_outputs( array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) output_arrays.append(array) - # If there is no inplace operation, we know that all output + # If there is no inplace or scatter operation, we know that all output # arrays don't alias. Informing llvm can make it easier to vectorize. - if not inplace: + if not inplace and not update_outputs: # The first argument is the output pointer arg = builder.function.args[0] arg.add_attribute("noalias") @@ -344,11 +378,22 @@ def make_loop_call( idx_types: tuple | None = None, idx_load_axis: tuple[int, ...] | None = None, idx_bc: tuple[tuple[bool, ...], ...] | None = None, + output_update_spec: tuple[tuple[int, int] | None, ...] 
| None = None, ): safe = (False, False) n_outputs = len(outputs) + # TODO I think this is better than the noalias attribute + # for the input, but self_ref isn't supported in a released + # llvmlite version yet + # mod = builder.module + # domain = mod.add_metadata([], self_ref=True) + # input_scope = mod.add_metadata([domain], self_ref=True) + # output_scope = mod.add_metadata([domain], self_ref=True) + # input_scope_set = mod.add_metadata([input_scope, output_scope]) + # output_scope_set = mod.add_metadata([input_scope, output_scope]) + zero = ir.Constant(ir.IntType(64), 0) def _wrap_negative_index(idx_val, dim_size, signed): @@ -467,6 +512,8 @@ def _wrap_negative_index(idx_val, dim_size, signed): if core_scalar and core_ndim == 0: # Retrive scalar item at index val = builder.load(ptr) + # val.set_metadata("alias.scope", input_scope_set) + # val.set_metadata("noalias", output_scope_set) else: # Retrieve array item at index # This is a streamlined version of Numba's `GUArrayArg.load` @@ -506,9 +553,33 @@ def _wrap_negative_index(idx_val, dim_size, signed): output_shape = cgutils.unpack_tuple(builder, output.shape) # pyright: ignore[reportAttributeAccessIssue] output_strides = cgutils.unpack_tuple(builder, output.strides) # pyright: ignore[reportAttributeAccessIssue] - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [ - zero - ] * core_ndim + spec = output_update_spec[output_i] if output_update_spec is not None else None + if spec is not None: + assert idx_types is not None + # Indexed-update output: same logic as indexed-read input + indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} + n_indexed = len(indexed_axes) + source_batch_ndim = len(bc) + n_indexed - 1 + idxs_bc = [] + loop_dim = 0 + for src_dim in range(source_batch_ndim): + if src_dim in indexed_axes: + idx_k = indexed_axes[src_dim] + idx_val = _wrap_negative_index( + indirect_idxs[idx_k], + output_shape[src_dim], + signed=idx_types[idx_k].dtype.signed, + ) + 
idxs_bc.append(idx_val) + else: + bc_dim = bc[loop_dim] if loop_dim < len(bc) else False + idxs_bc.append(zero if bc_dim else idxs[loop_dim]) + loop_dim += 1 + idxs_bc += [zero] * core_ndim + else: + idxs_bc = [ + zero if bc else idx for idx, bc in zip(idxs, bc, strict=True) + ] + [zero] * core_ndim ptr = cgutils.get_item_pointer2( context, builder, @@ -571,25 +642,28 @@ def _vectorized( indexed_inputs, indexed_outputs, ): - """Like _vectorized but with indirect indexing for reads. + """Like _vectorized but with indirect indexing for reads and updates. Outer inputs are ordered as - ``[elemwise_inputs..., idx_0, idx_1, ...]``. + ``[elemwise_inputs..., idx_0, idx_1, ..., update_target_0, ...]``. ``indexed_inputs`` groups elemwise input positions by which index they read through: e.g. ``((0, 2), (1,))`` means idx_0 reads inputs 0 and 2, idx_1 reads input 1. Entries may be empty ``()`` for update-only indices. - ``indexed_outputs`` is unused (always empty tuple for reads-only). + ``indexed_outputs`` has one entry per index array (same length as + ``indexed_inputs``). ``None`` means that index is not used for updates. + ``((out_0, out_1), mode, axis)`` means that index updates outputs out_0 and + out_1 with *mode* ``"set"`` or ``"inc"``. Parameters ---------- outer_input_types : tuple of Array types - ``(elemwise_input_0, ..., elemwise_input_N, idx_0, ...)`` + ``(elemwise_input_0, ..., elemwise_input_N, idx_0, ..., update_target_0, ...)`` indexed_inputs : literal str Encoded ``tuple[tuple[int, ...], ...]``. indexed_outputs : literal str - Encoded ``tuple`` (always empty). + Encoded ``tuple[tuple[tuple[int, ...], str, int] | None, ...]``. 
""" arg_types = [ core_func, @@ -617,19 +691,23 @@ def _vectorized( raise TypingError("allow_core_scalar must be literal.") allow_core_scalar = allow_core_scalar.literal_value + # Count scatter targets (one per scattered output) + n_update_targets = sum( + len(entry[0]) for entry in indexed_outputs if entry is not None + ) + n_indices = len(indexed_inputs) - n_elemwise = len(outer_input_types) - n_indices - source_input_types = tuple(outer_input_types[i] for i in range(n_elemwise)) + n_elemwise = len(outer_input_types) - n_indices - n_update_targets + input_types = tuple(outer_input_types[i] for i in range(n_elemwise)) idx_types = tuple(outer_input_types[n_elemwise + k] for k in range(n_indices)) + update_target_types = tuple( + outer_input_types[n_elemwise + n_indices + j] for j in range(n_update_targets) + ) - # indexed_inputs entries are (positions, axis) -- one per index array. - # For multi-index (e.g. x[idx_row, idx_col]), an input appears in multiple - # entries with different axes. We aggregate per-input into a tuple of - # (idx_k, src_axis) pairs. + # indexed_inputs entries are (positions, axis) — one per index array. + # We aggregate per-input into a tuple of (idx_k, src_axis) pairs. # # idx_load_axis[k] = which loop dim loads index array k. - # For multi-index on consecutive axes starting at A, all arrays in the - # group load from loop dim A (the first dim of the group in the output). _read_spec_dict: dict[int, list[tuple[int, int]]] = {} idx_load_axis = [] idx_bc_list = [] # per index array: broadcastable tuple @@ -642,37 +720,36 @@ def _vectorized( # idx_load_axis[k] = which loop dim to use when loading index array k. 
for k, entry in enumerate(indexed_inputs): _positions, axis = entry[0], entry[1] - # Find if this index array shares an input with other indexed axes - # If so, the loop dim is the minimum axis in that group - min_axis = axis - for p in _positions: - if p in _read_spec_dict: - for _other_k, other_axis in _read_spec_dict[p]: - min_axis = min(min_axis, other_axis) - idx_load_axis.append(min_axis) + idx_load_axis.append(axis) input_read_spec = tuple( tuple(_read_spec_dict[p]) if p in _read_spec_dict else None for p in range(n_elemwise) ) idx_load_axis = tuple(idx_load_axis) - # Build effective input types that match input_bc_patterns ndim. - # For multi-indexed inputs, the source has more dims than the bc pattern - # (multiple source axes collapse into fewer loop dims). - input_types = [] - for p, src_type in enumerate(source_input_types): - spec = input_read_spec[p] - if spec is not None and len(spec) > 1: - # Multi-index: effective ndim = source ndim - n_indexed + 1 - effective_ndim = src_type.ndim - len(spec) + 1 - input_types.append( - types.Array(src_type.dtype, effective_ndim, src_type.layout) - ) - else: - input_types.append(src_type) - input_types = tuple(input_types) + # Per-output: tuple of (idx_k, axis) pairs, or None. + # Same format as input_read_spec. + # indexed_outputs entries are (positions, mode, axis) or None. 
+ _update_spec_dict: dict[int, list[tuple[int, int]]] = {} + update_out_to_target = {} + update_out_indices = set() + target_counter = 0 + for k, entry in enumerate(indexed_outputs): + if entry is None: + continue + output_indices, _mode, axis = entry + for out_idx in output_indices: + _update_spec_dict.setdefault(out_idx, []).append((k, axis)) + update_out_to_target[out_idx] = target_counter + update_out_indices.add(out_idx) + target_counter += 1 + output_update_spec = tuple( + tuple(_update_spec_dict[p]) if p in _update_spec_dict else None + for p in range(len(output_bc_patterns)) + ) + update_out_indices = frozenset(update_out_indices) - core_input_types, core_out_types, _out_types, ret_type = _compute_vectorized_types( + core_input_types, core_out_types, out_types, ret_type = _compute_vectorized_types( input_types, input_bc_patterns, output_bc_patterns, @@ -682,6 +759,27 @@ def _vectorized( output_core_shape_types, ) + # Fix up output types for scattered outputs: they match the target buffer + if update_out_to_target: + core_out_types = list(core_out_types) + out_types = list(out_types) + batch_ndim = len(input_bc_patterns[0]) + for out_idx, target_idx in update_out_to_target.items(): + target_type = update_target_types[target_idx] + out_types[out_idx] = target_type + core_out_types[out_idx] = types.Array( + dtype=target_type.dtype, + ndim=target_type.ndim - batch_ndim, + layout=target_type.layout, + ) + out_types = tuple(out_types) + core_out_types = tuple(core_out_types) + + if len(out_types) == 1: + ret_type = out_types[0] + else: + ret_type = types.Tuple(out_types) + sig = ret_type(*arg_types) size_is_none = isinstance(size_type, NoneType) @@ -695,6 +793,10 @@ def _vectorized( idx_types_val = idx_types idx_load_axis_val = idx_load_axis idx_bc_list_val = idx_bc_list + output_update_spec_val = output_update_spec + update_out_to_target_val = update_out_to_target + update_target_types_val = update_target_types + update_out_indices_val = update_out_indices def 
codegen(ctx, builder, sig, args): [ @@ -722,7 +824,7 @@ def codegen(ctx, builder, sig, args): # First n_elemwise outer inputs are elemwise inputs (source arrays) inputs = [ - arrayobj.make_array(source_input_types[i])(ctx, builder, all_outer[i]) + arrayobj.make_array(input_types[i])(ctx, builder, all_outer[i]) for i in range(n_elemwise) ] in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] @@ -735,12 +837,17 @@ def codegen(ctx, builder, sig, args): for k in range(n_indices) ] + # Remaining inputs are scatter target buffers + update_target_arrs = [ + arrayobj.make_array(update_target_types_val[j])( + ctx, builder, all_outer[n_elemwise + n_indices + j] + ) + for j in range(n_update_targets) + ] + # Build iter_shapes for compute_itershape. - # For indexed inputs, the source array may have more dims than the - # iteration shape (multi-index collapses multiple source axes into one - # loop dim). Replace the source shape with a constructed shape that - # matches the bc pattern: one entry per loop dim, with index lengths - # substituted for the indexed loop dim(s). + # For indexed inputs, substitute the index array length on the + # indexed axis of the source shape. iter_shapes = list(in_shapes) iter_bc = list(input_bc_patterns_val) idx_shapes = [ @@ -749,41 +856,32 @@ def codegen(ctx, builder, sig, args): for p, spec in enumerate(input_read_spec_val): if spec is None: continue - indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} - n_indexed = len(indexed_axes) - if n_indexed <= 1: - # Single-index: substitute on the indexed axis - idx_k, axis = spec[0] - iter_shapes[p] = list(iter_shapes[p]) - iter_shapes[p][axis] = idx_shapes[idx_k][0] - else: - # Multi-index: build collapsed shape from non-indexed source - # dims + index length for the indexed loop dim. 
- source_shape = cgutils.unpack_tuple(builder, inputs[p].shape) - batch_ndim = len(input_bc_patterns_val[p]) - new_shape = [] - for loop_d in range(batch_ndim): - if loop_d == 0: - # First loop dim = index length - new_shape.append(idx_shapes[spec[0][0]][0]) - else: - # Remaining loop dims = non-indexed source axes - src_d = n_indexed - 1 + loop_d - new_shape.append(source_shape[src_d]) - iter_shapes[p] = new_shape + # Single-index: substitute on the indexed axis + idx_k, axis = spec[0] + iter_shapes[p] = list(iter_shapes[p]) + iter_shapes[p][axis] = idx_shapes[idx_k][0] # Each index array participates in iter_shape validation on its load - # axis, just like any direct input. Its bc comes from the PyTensor - # type (encoded in idx_bc_list), so shape-1 indices broadcast correctly. + # axis, just like any direct input. + # + # Read indices use their static broadcastable so shape-1 indices + # broadcast correctly against other inputs. + # + # Write indices are forced to bc=False: unlike reads, numpy does + # not allow the index to broadcast against the update value. A + # shape-1 write index at runtime must match a size-1 loop, not + # silently repeat writes to the same target position. 
batch_ndim = len(input_bc_patterns_val[0]) if input_bc_patterns_val else 0 one = ir.IntType(64)(1) for k in range(n_indices): ax = idx_load_axis_val[k] + is_write = indexed_outputs[k] is not None idx_shape_entry = [one] * batch_ndim idx_shape_entry[ax] = idx_shapes[k][0] iter_shapes.append(idx_shape_entry) - # bc on load axis from the index's own broadcastable; True elsewhere idx_bc_on_ax = idx_bc_list_val[k][0] if idx_bc_list_val[k] else False + if is_write: + idx_bc_on_ax = False iter_bc.append( tuple(True if d != ax else idx_bc_on_ax for d in range(batch_ndim)) ) @@ -796,6 +894,19 @@ def codegen(ctx, builder, sig, args): size, ) + # Build update_outputs dict for make_outputs: out_idx -> (array, type) + update_outputs_dict = ( + { + out_idx: ( + update_target_arrs[target_idx], + update_target_types_val[target_idx], + ) + for out_idx, target_idx in update_out_to_target_val.items() + } + if update_out_to_target_val + else None + ) + outputs, output_types = make_outputs( ctx, builder, @@ -804,8 +915,9 @@ def codegen(ctx, builder, sig, args): output_dtypes_val, inplace_pattern_val, inputs, - source_input_types, + input_types, output_core_shapes, + update_outputs=update_outputs_dict, ) core_signature = typingctx.resolve_function_type( @@ -830,7 +942,7 @@ def codegen(ctx, builder, sig, args): outputs, input_bc_patterns_val, output_bc_patterns_val, - source_input_types, + input_types, output_types, core_scalar=allow_core_scalar, input_read_spec=input_read_spec_val, @@ -838,6 +950,7 @@ def codegen(ctx, builder, sig, args): idx_types=idx_types_val, idx_load_axis=idx_load_axis_val, idx_bc=idx_bc_list_val, + output_update_spec=output_update_spec_val, ) return _codegen_return_outputs( @@ -846,6 +959,7 @@ def codegen(ctx, builder, sig, args): sig, outputs, inplace_pattern, + extra_incref=update_out_indices_val, ) return sig, codegen diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index 172e81ad5a..d180530a16 100644 
--- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -1,9 +1,9 @@ -"""Fuse indexed reads into Elemwise iteration loops. +"""Fuse indexed reads and updates into Elemwise iteration loops. Introduces ``IndexedElemwise``, an ``OpFromGraph`` that wraps -``AdvancedSubtensor1`` + ``Elemwise`` subgraphs so the Numba backend can -generate a single loop with indirect indexing, eliminating materialised -intermediate arrays. +``AdvancedSubtensor1`` + ``Elemwise`` + ``AdvancedIncSubtensor1`` subgraphs +so the Numba backend can generate a single loop with indirect indexing, +eliminating materialised intermediate arrays. """ from pytensor.compile import optdb @@ -11,9 +11,11 @@ from pytensor.graph.rewriting.basic import GraphRewriter from pytensor.graph.rewriting.db import SequenceDB from pytensor.printing import op_debug_information +from pytensor.scalar.basic import Composite, identity from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor1, AdvancedSubtensor1, ) @@ -30,9 +32,10 @@ class IndexedElemwise(OpFromGraph): - """Fuse indexed reads into a single Elemwise iteration loop. + """Fuse indexed reads and updates into a single Elemwise iteration loop. - Absorbs ``AdvancedSubtensor1`` (indexed reads on inputs) into one loop, + Absorbs ``AdvancedSubtensor1`` (indexed reads on inputs) and + ``AdvancedIncSubtensor1`` (indexed updates on outputs) into one loop, avoiding materialisation of intermediate arrays. Inner fgraph contains the unfused subgraph. @@ -41,7 +44,7 @@ class IndexedElemwise(OpFromGraph): Outer inputs are ordered as:: - [elemwise_inputs..., idx_0, idx_1, ...] + [elemwise_inputs..., idx_0, idx_1, ..., update_target_0, ...] Elemwise inputs whose values are read via an index have their source arrays substituted in place. 
@@ -50,10 +53,13 @@ class IndexedElemwise(OpFromGraph): ---------- indexed_inputs : tuple of (tuple[int, ...], int, tuple[bool, ...]) One entry per index array: ``(elemwise_input_positions, axis, idx_broadcastable)``. + indexed_outputs : tuple of ((tuple[int, ...], str, int) | None) + One entry per index array: ``(output_positions, mode, axis)`` or ``None``. """ - def __init__(self, *args, indexed_inputs=(), **kwargs): + def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): self.indexed_inputs = indexed_inputs + self.indexed_outputs = indexed_outputs super().__init__(*args, on_unused_input="ignore", **kwargs) def __str__(self): @@ -68,7 +74,8 @@ def _op_debug_information_IndexedElemwise(op, node): info = {} n_idx = len(op.indexed_inputs) - n_elemwise = len(node.inputs) - n_idx + n_update_targets = sum(1 for e in op.indexed_outputs if e is not None) + n_elemwise = len(node.inputs) - n_idx - n_update_targets # Annotate indexed-read inputs for k, (positions, _axis, _bc) in enumerate(op.indexed_inputs): @@ -83,18 +90,42 @@ def _op_debug_information_IndexedElemwise(op, node): if idx_pos < len(node.inputs): info[node.inputs[idx_pos]] = f"idx_{k}" + # Annotate update targets and outputs + buf_counter = 0 + target_start = n_elemwise + n_idx + target_offset = 0 + for k, entry in enumerate(op.indexed_outputs): + if entry is None: + continue + out_positions, mode, _axis = entry + buf_label = f"buf_{buf_counter}" + buf_counter += 1 + idx_label = f"idx_{k}" + + target_pos = target_start + target_offset + target_offset += 1 + if target_pos < len(node.inputs): + info[node.inputs[target_pos]] = buf_label + + for out_idx in out_positions: + if out_idx < len(node.outputs): + info[node.outputs[out_idx]] = ( + f"indexed {mode} ({buf_label}, {idx_label})" + ) + return {node: info} class FuseIndexedElemwise(GraphRewriter): - """Fuse indexed reads into Elemwise loops. + """Fuse indexed reads and indexed updates into Elemwise loops. 
Absorbs single-client ``AdvancedSubtensor1`` on inputs (indexed reads) + and single-client ``AdvancedIncSubtensor1`` on outputs (indexed updates) into the Elemwise iteration, avoiding intermediate arrays. Supports multiple index arrays: e.g. ``x[idx_a] + y[idx_b]`` produces - two index groups. Index arrays are shared between reads when they refer - to the same variable. + two index groups. Index arrays are shared between reads and updates + when they refer to the same variable. """ def apply(self, fgraph): @@ -134,13 +165,49 @@ def find_indexed_input_groups(fgraph, node): return [(var, axis, tuple(pos)) for var, axis, pos in groups.values()] + def find_indexed_update_consumers(fgraph, node): + """Find AdvancedIncSubtensor1 consumers of Elemwise outputs. + + Returns ``{out_idx: (update_node, target, idx_var, mode)}``. + Only considers outputs that are the value input (position 1) of + the indexed update. + """ + update_info = {} + for out_idx, out in enumerate(node.outputs): + clients = fgraph.clients[out] + # Find AdvancedIncSubtensor1 client at val position (1) + inc_clients = [ + (c, ci) + for c, ci in clients + if ci == 1 and isinstance(c.op, AdvancedIncSubtensor1) + ] + if len(inc_clients) != 1: + continue + [(client_node, _client_inp_idx)] = inc_clients + target, val, idx_var = client_node.inputs + # Don't fuse if the value broadcasts on the index loop dim + # (constant across index — recomputing per position is wasteful) + # or against non-indexed target axes. 
+ if val.type.broadcastable[0]: + continue + val_bc = val.type.broadcastable[1:] + target_bc = target.type.broadcastable[1:] + if len(val_bc) < len(target_bc) or any( + vbc and not tbc for vbc, tbc in zip(val_bc, target_bc, strict=False) + ): + continue + mode = "set" if client_node.op.set_instead_of_inc else "inc" + update_info[out_idx] = (client_node, target, idx_var, mode) + return update_info + for node in reversed(fgraph.toposort()): if not isinstance(node.op, Elemwise): continue read_groups = find_indexed_input_groups(fgraph, node) + update_consumers = find_indexed_update_consumers(fgraph, node) - if not read_groups: + if not read_groups and not update_consumers: continue indexed_positions = { @@ -168,13 +235,19 @@ def find_indexed_input_groups(fgraph, node): reason="fuse_indexed_elemwise_reinplace", extra_protected_inputs=protected, ) + # Re-detect after inplace change + update_consumers = find_indexed_update_consumers(fgraph, node) - # Merge read index arrays into a unified ordered list + # Merge read and update index arrays into a unified ordered list all_idx_groups = {} # (idx_var, axis) -> (idx_var, position) for idx_var, _ax, _ in read_groups: key = (idx_var, _ax) if key not in all_idx_groups: all_idx_groups[key] = (idx_var, len(all_idx_groups)) + for _un, _target, idx_var, _mode in update_consumers.values(): + key = (idx_var, 0) # updates are axis 0 for now + if key not in all_idx_groups: + all_idx_groups[key] = (idx_var, len(all_idx_groups)) n_indices = len(all_idx_groups) idx_vars = [None] * n_indices @@ -184,18 +257,94 @@ def find_indexed_input_groups(fgraph, node): # Build destroy_map outer_destroy_map = {} for out_idx, inp_idx in node.op.inplace_pattern.items(): - if inp_idx not in indexed_positions: + if inp_idx not in indexed_positions and out_idx not in update_consumers: outer_destroy_map[out_idx] = [inp_idx] # Inner fgraph inputs: - # [elemwise_inputs (sources substituted)..., idx_0, ...] 
+ # [elemwise_inputs (sources substituted)..., idx_0, ..., target_0, ...] inner_inputs = [ inp.owner.inputs[0] if i in indexed_positions else inp for i, inp in enumerate(node.inputs) ] inner_inputs = inner_inputs + idx_vars - outer_inputs = list(inner_inputs) + # If any scatter output also has other consumers, duplicate + # the elemwise output via Composite so the scatter can replace + # the duplicate while the original stays available. + multi_client_outs = set() + for out_idx in update_consumers: + update_node = update_consumers[out_idx][0] + if any( + c is not update_node + for c, _ in fgraph.clients[node.outputs[out_idx]] + ): + multi_client_outs.add(out_idx) + + if multi_client_outs: + scalar_op = node.op.scalar_op + if isinstance(scalar_op, Composite): + s_inputs = list(scalar_op.inputs) + s_outputs = list(scalar_op.outputs) + else: + scalar_node = scalar_op.make_node( + *[inp.type.to_scalar_type()() for inp in node.inputs] + ) + s_inputs = list(scalar_node.inputs) + s_outputs = list(scalar_node.outputs) + + # Wrap duplicates with identity so Composite._cleanup_graph + # doesn't clone the entire subgraph for repeated outputs. + # TODO: _cleanup_graph should use identity instead of clone + # for duplicate outputs. 
+ dup_map = {} + for out_idx in sorted(multi_client_outs): + dup_map[out_idx] = len(s_outputs) + s_outputs.append(identity(s_outputs[out_idx])) + + new_scalar_op = Composite(s_inputs, s_outputs) + new_elemwise = Elemwise(new_scalar_op)(*node.inputs, return_list=True) + + old_node = node + node = new_elemwise[0].owner + + inner_inputs = [ + inp.owner.inputs[0] if i in indexed_positions else inp + for i, inp in enumerate(node.inputs) + ] + inner_inputs = inner_inputs + idx_vars + else: + dup_map = {} + + # Inner fgraph outputs; add update targets + inner_outputs = list(node.outputs) + call_inputs = list(inner_inputs) + for out_idx in sorted(update_consumers.keys()): + update_node, target, idx_var, _mode = update_consumers[out_idx] + + inner_inputs.append(target) + + target_pos = len(call_inputs) + if update_node.op.inplace: + call_inputs.append(target) + else: + call_inputs.append(target.copy()) + + scatter_idx = dup_map.get(out_idx, out_idx) + scatter_value = node.outputs[scatter_idx] + + if update_node.op.inplace: + scatter_out = update_node.op( + target, scatter_value, update_node.inputs[2] + ) + else: + inplace_op = AdvancedIncSubtensor1( + inplace=True, + set_instead_of_inc=update_node.op.set_instead_of_inc, + ) + scatter_out = inplace_op(target, scatter_value, idx_var) + + inner_outputs[scatter_idx] = scatter_out + outer_destroy_map[scatter_idx] = [target_pos] # Build indexed_inputs spec for the Op indexed_inputs_spec = [None] * n_indices @@ -207,7 +356,7 @@ def find_indexed_input_groups(fgraph, node): _ax, idx_var.type.broadcastable, ) - # Fill entries for index arrays with no reads + # Fill entries for index arrays with no reads (write-only) for k in range(n_indices): if indexed_inputs_spec[k] is None: for (iv, ax), (v, p) in all_idx_groups.items(): @@ -215,15 +364,46 @@ def find_indexed_input_groups(fgraph, node): indexed_inputs_spec[k] = ((), ax, iv.type.broadcastable) break + # Build indexed_outputs spec for the Op + indexed_outputs_spec = [None] * 
n_indices + for out_idx in sorted(update_consumers.keys()): + _update_node, _target, idx_var, mode = update_consumers[out_idx] + key = (idx_var, 0) + idx_pos = all_idx_groups[key][1] + scatter_idx = dup_map.get(out_idx, out_idx) + if indexed_outputs_spec[idx_pos] is None: + indexed_outputs_spec[idx_pos] = ([scatter_idx], mode, 0) + else: + indexed_outputs_spec[idx_pos][0].append(scatter_idx) + indexed_outputs_spec = tuple( + (tuple(e[0]), e[1], e[2]) if e is not None else None + for e in indexed_outputs_spec + ) + new_outs = IndexedElemwise( inner_inputs, - node.outputs, + inner_outputs, destroy_map=outer_destroy_map, indexed_inputs=tuple(indexed_inputs_spec), - )(*outer_inputs, return_list=True) + indexed_outputs=indexed_outputs_spec, + )(*call_inputs, return_list=True) + + orig_node = old_node if multi_client_outs else node + replacements = [] + for out_idx in range(len(orig_node.outputs)): + if out_idx in update_consumers: + update_node = update_consumers[out_idx][0] + scatter_idx = dup_map.get(out_idx, out_idx) + replacements.append((update_node.outputs[0], new_outs[scatter_idx])) + if out_idx in dup_map: + replacements.append( + (orig_node.outputs[out_idx], new_outs[out_idx]) + ) + else: + replacements.append((orig_node.outputs[out_idx], new_outs[out_idx])) fgraph.replace_all_validate( - list(zip(node.outputs, new_outs)), + replacements, reason="fuse_indexed_into_elemwise", ) diff --git a/tests/benchmarks/test_gather_fusion.py b/tests/benchmarks/test_gather_fusion.py index 04525a292d..9c545d08bf 100644 --- a/tests/benchmarks/test_gather_fusion.py +++ b/tests/benchmarks/test_gather_fusion.py @@ -1,7 +1,8 @@ -"""Micro-benchmarks for Elemwise fusion with indexed reads. +"""Micro-benchmarks for Elemwise fusion with indexed reads and updates. -Tests the benefit of fusing AdvancedSubtensor1 (indexed reads) into Elemwise -loops, avoiding materialization of intermediate arrays. 
+Tests the benefit of fusing AdvancedSubtensor1 (indexed reads) and +AdvancedIncSubtensor1 (indexed updates) into Elemwise loops, avoiding +materialization of intermediate arrays. """ import numpy as np @@ -12,7 +13,7 @@ from pytensor import config from pytensor.compile.mode import get_mode from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise -from pytensor.tensor.subtensor import advanced_subtensor1 +from pytensor.tensor.subtensor import AdvancedIncSubtensor1, advanced_subtensor1 @pytest.fixture( @@ -73,3 +74,79 @@ def test_gather_fusion_unfused(gather_benchmark_setup, benchmark): _, fn_unfused, vals = gather_benchmark_setup fn_unfused(*vals) # warmup benchmark(fn_unfused, *vals) + + +# --------------------------------------------------------------------------- +# Scatter output fusion benchmarks +# --------------------------------------------------------------------------- + + +@pytest.fixture( + params=[ + (85, 919, 2, 6, "inc"), # radon-like: small, inc mode + (85, 919, 2, 6, "set"), # radon-like: small, set mode + (1000, 100_000, 2, 4, "inc"), # medium, inc mode + ], + ids=["small-85x919-inc", "small-85x919-set", "medium-1Kx100K-inc"], +) +def scatter_benchmark_setup(request): + n_bins, n_data, n_gathered, n_direct, mode = request.param + + rng = np.random.default_rng(42) + idx = rng.integers(n_bins, size=n_data).astype(np.int64) + idx.sort() + + sources = [pt.vector(f"src_{i}", shape=(n_bins,)) for i in range(n_gathered)] + directs = [pt.vector(f"dir_{i}", shape=(n_data,)) for i in range(n_direct)] + target = pt.vector("target", shape=(n_bins,)) + + terms = [advanced_subtensor1(s, idx) for s in sources] + directs + elemwise_out = terms[0] + for t in terms[1:]: + elemwise_out = elemwise_out + t + + if mode == "inc": + out = pt.inc_subtensor(target[idx], elemwise_out) + else: + out = pt.set_subtensor(target[idx], elemwise_out) + + inputs = sources + directs + [target] + numba_mode = get_mode("NUMBA") + + fn_fused = pytensor.function(inputs, 
out, mode=numba_mode, trust_input=True) + fn_unfused = pytensor.function( + inputs, + out, + mode=numba_mode.excluding("fuse_indexed_into_elemwise"), + trust_input=True, + ) + + assert any( + isinstance(n.op, IndexedElemwise) for n in fn_fused.maker.fgraph.toposort() + ), "IndexedElemwise not found in fused graph" + assert not any( + isinstance(n.op, AdvancedIncSubtensor1) + for n in fn_fused.maker.fgraph.toposort() + ), "AdvancedIncSubtensor1 still present in fused graph" + assert not any( + isinstance(n.op, IndexedElemwise) for n in fn_unfused.maker.fgraph.toposort() + ), "IndexedElemwise found in unfused graph" + + rng = np.random.default_rng(1) + vals = [rng.normal(size=inp.type.shape).astype(config.floatX) for inp in inputs] + + np.testing.assert_allclose(fn_fused(*vals), fn_unfused(*vals), rtol=1e-10) + + return fn_fused, fn_unfused, vals + + +def test_scatter_fusion_fused(scatter_benchmark_setup, benchmark): + fn_fused, _, vals = scatter_benchmark_setup + fn_fused(*vals) # warmup + benchmark(fn_fused, *vals) + + +def test_scatter_fusion_unfused(scatter_benchmark_setup, benchmark): + _, fn_unfused, vals = scatter_benchmark_setup + fn_unfused(*vals) # warmup + benchmark(fn_unfused, *vals) diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index f5ea8bea15..049062d118 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -1,13 +1,18 @@ -"""Tests for IndexedElemwise fusion (indexed reads in Elemwise loops).""" +"""Tests for IndexedElemwise fusion (indexed reads and updates in Elemwise loops).""" import numpy as np import pytest import pytensor import pytensor.tensor as pt +from pytensor import config from pytensor.compile.mode import get_mode from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise -from pytensor.tensor.subtensor import advanced_subtensor1 +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor1, + AdvancedSubtensor, + 
advanced_subtensor1, +) numba = pytest.importorskip("numba") @@ -30,8 +35,13 @@ def assert_fused(fn): ) +# ============================================================ +# Correctness tests — indexed reads +# ============================================================ + + class TestIndexedReadFusion: - """Test indexed reads (AdvancedSubtensor1) fused into Elemwise.""" + """Test indexed reads (AdvancedSubtensor1 / AdvancedSubtensor) fused into Elemwise.""" def test_single_index_axis0(self): rng = np.random.default_rng(42) @@ -43,6 +53,42 @@ def test_single_index_axis0(self): xv, yv = rng.normal(size=(85,)), rng.normal(size=(919,)) np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + def test_single_index_axis1(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = pt.matrix("x", shape=(3, 85)) + y = pt.matrix("y", shape=(3, 919)) + fn, fn_u = fused_and_unfused([x, y], x[:, idx] + y) + assert_fused(fn) + xv, yv = rng.normal(size=(3, 85)), rng.normal(size=(3, 919)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + def test_multi_index_2d(self): + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + y = pt.vector("y", shape=(50,)) + fn, fn_u = fused_and_unfused([x, ir, ic, y], x[ir, ic] + y) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + rv, cv = rng.integers(100, size=50), rng.integers(200, size=50) + yv = rng.normal(size=(50,)) + np.testing.assert_allclose(fn(xv, rv, cv, yv), fn_u(xv, rv, cv, yv), rtol=1e-10) + + def test_multi_index_3d_trailing_dim(self): + rng = np.random.default_rng(42) + x3 = pt.tensor3("x3", shape=(100, 200, 5)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + z = pt.matrix("z", shape=(50, 5)) + fn, fn_u = fused_and_unfused([x3, ir, ic, z], x3[ir, ic] + z) + assert_fused(fn) + xv = 
rng.normal(size=(100, 200, 5)) + rv, cv = rng.integers(100, size=50), rng.integers(200, size=50) + zv = rng.normal(size=(50, 5)) + np.testing.assert_allclose(fn(xv, rv, cv, zv), fn_u(xv, rv, cv, zv), rtol=1e-10) + def test_multiple_gathered_sources(self): rng = np.random.default_rng(42) idx = rng.integers(85, size=919).astype(np.int64) @@ -72,6 +118,133 @@ def test_broadcast_index_axis0(self): yv = np.ones(50) np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + def test_broadcast_index_axis1(self): + """Static shape=(1,) index on axis 1 broadcasts against larger direct input.""" + x = pt.matrix("x", shape=(3, 100)) + y = pt.matrix("y", shape=(3, 50)) + idx = np.array([5], dtype=np.int64) + out = x[:, idx] + y # x[:, idx] has shape (3, 1), broadcasts to (3, 50) + fn, fn_u = fused_and_unfused([x, y], out) + assert_fused(fn) + xv = np.arange(300.0).reshape(3, 100) + yv = np.ones((3, 50)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + def test_nd_index_axis0(self): + """2D matrix index on axis 0.""" + rng = np.random.default_rng(42) + x = pt.vector("x", shape=(100,)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) + y = pt.matrix("y", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, mat_idx, y], pt.exp(x[mat_idx]) + y) + assert_fused(fn) + xv = rng.normal(size=(100,)) + iv = rng.integers(100, size=(10, 5)).astype(np.int64) + yv = rng.normal(size=(10, 5)) + np.testing.assert_allclose(fn(xv, iv, yv), fn_u(xv, iv, yv), rtol=1e-10) + + def test_nd_index_axis1(self): + """2D matrix index on axis 1 (via undo_take_reshape_for_fusion).""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(3, 100)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, mat_idx], pt.exp(x[:, mat_idx])) + assert_fused(fn) + xv = rng.normal(size=(3, 100)) + iv = rng.integers(100, size=(10, 5)).astype(np.int64) + np.testing.assert_allclose(fn(xv, iv), fn_u(xv, iv), rtol=1e-10) + + def 
test_nd_index_with_trailing_dims(self): + """2D index on axis 0 with trailing source dims.""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 7)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, mat_idx], pt.exp(x[mat_idx])) + assert_fused(fn) + xv = rng.normal(size=(100, 7)) + iv = rng.integers(100, size=(10, 5)).astype(np.int64) + np.testing.assert_allclose(fn(xv, iv), fn_u(xv, iv), rtol=1e-10) + + def test_nd_index_broadcast(self): + """Broadcastable 2D index (shape (1, C)) broadcasts against direct input.""" + rng = np.random.default_rng(42) + x = pt.vector("x", shape=(100,)) + bc_idx = pt.matrix("bc_idx", dtype="int64", shape=(1, 5)) + y = pt.matrix("y", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, bc_idx, y], x[bc_idx] + y) + assert_fused(fn) + xv = rng.normal(size=(100,)) + bcv = rng.integers(100, size=(1, 5)).astype(np.int64) + yv = rng.normal(size=(10, 5)) + np.testing.assert_allclose(fn(xv, bcv, yv), fn_u(xv, bcv, yv), rtol=1e-10) + + def test_scalar_and_vector_index(self): + """x[scalar_idx, vector_idx] — 0-d index broadcasts with 1-d index.""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + y = pt.vector("y", shape=(50,)) + fn, fn_u = fused_and_unfused( + [x, scalar_idx, vec_idx, y], x[scalar_idx, vec_idx] + y + ) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + sv = np.array(rng.integers(100), dtype=np.int64) + vv = rng.integers(200, size=50).astype(np.int64) + yv = rng.normal(size=50) + np.testing.assert_allclose(fn(xv, sv, vv, yv), fn_u(xv, sv, vv, yv), rtol=1e-10) + + def test_vector_and_scalar_index(self): + """x[vector_idx, scalar_idx] — reversed order.""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + scalar_idx = pt.scalar("scalar_idx", 
dtype="int64") + y = pt.vector("y", shape=(50,)) + fn, fn_u = fused_and_unfused( + [x, vec_idx, scalar_idx, y], x[vec_idx, scalar_idx] + y + ) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + vv = rng.integers(100, size=50).astype(np.int64) + sv = np.array(rng.integers(200), dtype=np.int64) + yv = rng.normal(size=50) + np.testing.assert_allclose(fn(xv, vv, sv, yv), fn_u(xv, vv, sv, yv), rtol=1e-10) + + def test_scalar_and_vector_index_with_trailing_dim(self): + """x[scalar_idx, vector_idx] on a 3-d tensor with trailing dim.""" + rng = np.random.default_rng(42) + x = pt.tensor3("x", shape=(100, 200, 7)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + z = pt.matrix("z", shape=(50, 7)) + fn, fn_u = fused_and_unfused( + [x, scalar_idx, vec_idx, z], x[scalar_idx, vec_idx] + z + ) + assert_fused(fn) + xv = rng.normal(size=(100, 200, 7)) + sv = np.array(rng.integers(100), dtype=np.int64) + vv = rng.integers(200, size=50).astype(np.int64) + zv = rng.normal(size=(50, 7)) + np.testing.assert_allclose(fn(xv, sv, vv, zv), fn_u(xv, sv, vv, zv), rtol=1e-10) + + def test_all_scalar_indices(self): + """AdvancedSubtensor with all 0-d indices (degenerate case).""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + si0 = pt.scalar("si0", dtype="int64") + si1 = pt.scalar("si1", dtype="int64") + y = pt.scalar("y", dtype="float64") + adv_sub = AdvancedSubtensor(idx_list=(0, 1))(x, si0, si1) + fn, fn_u = fused_and_unfused([x, si0, si1, y], adv_sub + y) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + s0 = np.array(rng.integers(100), dtype=np.int64) + s1 = np.array(rng.integers(200), dtype=np.int64) + yv = np.array(rng.normal()) + np.testing.assert_allclose(fn(xv, s0, s1, yv), fn_u(xv, s0, s1, yv), rtol=1e-10) + def test_negative_indices(self): """Negative indices must be handled correctly (sign-extended, not zero-extended).""" rng = np.random.default_rng(42) @@ -88,6 +261,220 @@ def 
test_negative_indices(self): np.testing.assert_allclose(fn(xv, idxv, yv), fn_u(xv, idxv, yv), rtol=1e-10) +# ============================================================ +# Correctness tests — indexed updates +# ============================================================ + + +class TestIndexedUpdateFusion: + """Test indexed updates (AdvancedIncSubtensor1) fused into Elemwise.""" + + def test_no_fusion_when_idx_axes_outside_elemwise_loop(self): + """Don't fuse if the indexed axes are not within the Elemwise loop. + + Here the index is on axis 0 of target(5, 10), but the Elemwise + output (10,) corresponds to axis 1 (the non-indexed trailing axis). + The indexed axis doesn't overlap with the Elemwise computation, so + fusing would misalign which input dims map to which target dims. + """ + rng = np.random.default_rng(42) + idx = rng.integers(5, size=10).astype(np.int64) + target = pt.matrix("target", shape=(5, 10)) + x = pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(10,)) + elemwise_out = advanced_subtensor1(x, idx) + y # shape (10,) + out = pt.inc_subtensor(target[idx], elemwise_out) + fn, fn_u = fused_and_unfused([x, y, target], out) + # Write not fused — the Elemwise loop dim is the non-indexed axis, + # not the indexed axis. Read fusion may still create an + # IndexedElemwise, but the AdvancedIncSubtensor1 must remain outside. + assert any( + isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() + ) + xv = rng.normal(size=(85,)) + yv = rng.normal(size=(10,)) + tv = rng.normal(size=(5, 10)) + np.testing.assert_allclose(fn(xv, yv, tv), fn_u(xv, yv, tv), rtol=1e-10) + + def test_no_fusion_when_val_broadcasts_along_index_dim(self): + """Don't fuse if val is broadcastable on the index loop dim. + + When val is constant across the index (e.g. shape (1,) with idx + of size 10), fusing would recompute the same Elemwise result at + every index position. Better to compute once and scatter. 
+ """ + rng = np.random.default_rng(42) + idx = rng.integers(5, size=10).astype(np.int64) + target = pt.vector("target", shape=(5,)) + x = pt.tensor("x", shape=(1,)) + out = pt.inc_subtensor(target[idx], pt.exp(x)) + fn, fn_u = fused_and_unfused([x, target], out) + # Write not fused — val broadcasts along the index loop dim. + assert any( + isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() + ) + xv = rng.normal(size=(1,)) + tv = rng.normal(size=(5,)) + np.testing.assert_allclose(fn(xv, tv), fn_u(xv, tv), rtol=1e-10) + + def test_broadcast_val_into_non_indexed_dims(self): + """Fuse when Elemwise output broadcasts into target's non-indexed dims. + + target(5, 3)[:, idx] += exp(val) — the Elemwise loop covers the + index axis (axis 1), and the scalar result broadcasts into the + non-indexed leading axis (axis 0, core_ndim=1). + + Requires excluding local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1 + so we keep AdvancedIncSubtensor on the correct axis. + """ + rng = np.random.default_rng(42) + idx = rng.integers(3, size=10).astype(np.int64) + target = pt.matrix("target", shape=(5, 3)) + val = pt.vector("val", shape=(10,)) + out = pt.inc_subtensor(target[:, idx], pt.exp(val)) + + mode = NUMBA_MODE.excluding( + "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1" + ) + mode_u = NUMBA_NO_FUSION.excluding( + "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1" + ) + fn = pytensor.function([val, target], out, mode=mode, trust_input=True) + fn_u = pytensor.function([val, target], out, mode=mode_u, trust_input=True) + assert_fused(fn) + valv = rng.normal(size=(10,)) + tv = rng.normal(size=(5, 3)) + np.testing.assert_allclose(fn(valv, tv), fn_u(valv, tv), rtol=1e-10) + + def test_inc_subtensor(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(919,)) + t = pt.vector("t", shape=(85,)) + out = pt.inc_subtensor(t[idx], advanced_subtensor1(x, idx) + 
y) + fn, fn_u = fused_and_unfused([x, y, t], out) + assert_fused(fn) + xv, yv, tv = ( + rng.normal(size=(85,)), + rng.normal(size=(919,)), + rng.normal(size=(85,)), + ) + np.testing.assert_allclose(fn(xv, yv, tv), fn_u(xv, yv, tv), rtol=1e-10) + + def test_set_subtensor(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(919,)) + t = pt.vector("t", shape=(85,)) + out = pt.set_subtensor(t[idx], advanced_subtensor1(x, idx) + y) + fn, fn_u = fused_and_unfused([x, y, t], out) + assert_fused(fn) + xv, yv, tv = ( + rng.normal(size=(85,)), + rng.normal(size=(919,)), + rng.normal(size=(85,)), + ) + np.testing.assert_allclose(fn(xv, yv, tv), fn_u(xv, yv, tv), rtol=1e-10) + + def test_target_not_modified_when_non_inplace(self): + """Non-inplace scatter should not modify the original target.""" + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(919,)) + t = pt.vector("t", shape=(85,)) + out = pt.inc_subtensor(t[idx], advanced_subtensor1(x, idx) + y) + fn = pytensor.function([x, y, t], out, mode=NUMBA_MODE, trust_input=True) + xv, yv = rng.normal(size=(85,)), rng.normal(size=(919,)) + tv = rng.normal(size=(85,)) + tv_copy = tv.copy() + fn(xv, yv, tv) + np.testing.assert_array_equal(tv, tv_copy) + + def test_multi_index_inc_subtensor(self): + rng = np.random.default_rng(42) + target = pt.matrix("target", shape=(100, 200)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + x = pt.vector("x", shape=(50,)) + out = pt.inc_subtensor(target[ir, ic], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, ir, ic, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200)) + rv = rng.integers(100, size=50).astype(np.int64) + cv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=50) + np.testing.assert_allclose(fn(tv, rv, 
cv, xv), fn_u(tv, rv, cv, xv), rtol=1e-10) + + def test_multi_index_set_subtensor(self): + rng = np.random.default_rng(42) + target = pt.matrix("target", shape=(100, 200)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + x = pt.vector("x", shape=(50,)) + out = pt.set_subtensor(target[ir, ic], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, ir, ic, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200)) + rv = rng.integers(100, size=50).astype(np.int64) + cv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=50) + np.testing.assert_allclose(fn(tv, rv, cv, xv), fn_u(tv, rv, cv, xv), rtol=1e-10) + + def test_multi_index_write_with_trailing_dims(self): + rng = np.random.default_rng(42) + target = pt.tensor3("target", shape=(100, 200, 5)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + x = pt.matrix("x", shape=(50, 5)) + out = pt.inc_subtensor(target[ir, ic], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, ir, ic, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200, 5)) + rv = rng.integers(100, size=50).astype(np.int64) + cv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=(50, 5)) + np.testing.assert_allclose(fn(tv, rv, cv, xv), fn_u(tv, rv, cv, xv), rtol=1e-10) + + def test_combined_multi_index_read_write(self): + """Read and write share the same multi-index arrays.""" + rng = np.random.default_rng(42) + target = pt.matrix("target", shape=(100, 200)) + source = pt.matrix("source", shape=(100, 200)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + x = pt.vector("x", shape=(50,)) + out = pt.inc_subtensor(target[ir, ic], source[ir, ic] + x) + fn, fn_u = fused_and_unfused([target, source, ir, ic, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200)) + sv = rng.normal(size=(100, 200)) + rv = rng.integers(100, size=50).astype(np.int64) + cv = 
rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=50) + np.testing.assert_allclose( + fn(tv, sv, rv, cv, xv), fn_u(tv, sv, rv, cv, xv), rtol=1e-10 + ) + + def test_scalar_and_vector_index_inc(self): + """inc_subtensor with x[scalar_idx, vector_idx] — 0-d and 1-d indices.""" + rng = np.random.default_rng(42) + target = pt.matrix("target", shape=(100, 200)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + x = pt.vector("x", shape=(50,)) + out = pt.inc_subtensor(target[scalar_idx, vec_idx], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, scalar_idx, vec_idx, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200)) + sv = np.array(rng.integers(100), dtype=np.int64) + vv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=50) + np.testing.assert_allclose(fn(tv, sv, vv, xv), fn_u(tv, sv, vv, xv), rtol=1e-10) + + class TestShapeValidation: """Test that mismatched index/input shapes raise runtime errors. 
@@ -109,6 +496,56 @@ def test_mismatched_index_and_direct_input(self): with pytest.raises(Exception): fn(np.zeros(100), np.zeros(50, dtype=np.int64), np.zeros(49)) + def test_mismatched_multi_index_lengths(self): + """Two index arrays in a multi-index have different lengths.""" + x = pt.matrix("x", shape=(None, None)) + ir = pt.vector("ir", dtype="int64", shape=(None,)) + ic = pt.vector("ic", dtype="int64", shape=(None,)) + y = pt.vector("y", shape=(None,)) + out = x[ir, ic] + y + fn = pytensor.function([x, ir, ic, y], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Matching: all 50 — should work + fn( + np.zeros((100, 200)), + np.zeros(50, dtype=np.int64), + np.zeros(50, dtype=np.int64), + np.zeros(50), + ) + # Mismatched: ir=50, ic=49 — should error + with pytest.raises(Exception): + fn( + np.zeros((100, 200)), + np.zeros(50, dtype=np.int64), + np.zeros(49, dtype=np.int64), + np.zeros(50), + ) + + def test_mismatched_index_vs_direct_on_non_indexed_axis(self): + """Index and direct input disagree on a non-indexed (trailing) axis.""" + x = pt.tensor3("x", shape=(None, None, None)) + ir = pt.vector("ir", dtype="int64", shape=(None,)) + ic = pt.vector("ic", dtype="int64", shape=(None,)) + z = pt.matrix("z", shape=(None, None)) + out = x[ir, ic] + z # result shape (N, trailing) + fn = pytensor.function([x, ir, ic, z], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Matching: trailing dim = 5 for both + fn( + np.zeros((10, 20, 5)), + np.zeros(3, dtype=np.int64), + np.zeros(3, dtype=np.int64), + np.zeros((3, 5)), + ) + # Mismatched: x trailing=5, z trailing=4 + with pytest.raises(Exception): + fn( + np.zeros((10, 20, 5)), + np.zeros(3, dtype=np.int64), + np.zeros(3, dtype=np.int64), + np.zeros((3, 4)), + ) + def test_runtime_broadcast_on_index_dim(self): """Symbolic shapes that happen to be 1 at runtime — broadcast check.""" x = pt.vector("x", shape=(None,)) @@ -123,3 +560,151 @@ def test_runtime_broadcast_on_index_dim(self): # idx=1, y=5 — 
should error (shape mismatch, no static broadcast info)
         with pytest.raises(Exception):
             fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(5))
+
+    def test_shared_read_write_index_no_broadcast(self):
+        """A shared read+write index must not broadcast at runtime."""
+        x = pt.vector("x", shape=(None,))
+        y = pt.vector("y", shape=(None,))
+        t = pt.vector("t", shape=(None,))
+        idx = pt.vector("idx", dtype="int64", shape=(None,))
+        out = pt.inc_subtensor(t[idx], x[idx] + y)
+        fn = pytensor.function([x, y, t, idx], out, mode=NUMBA_MODE, trust_input=True)
+        fn(np.zeros(100), np.zeros(50), np.zeros(100), np.arange(50, dtype=np.int64))
+        with pytest.raises(Exception):
+            fn(
+                np.zeros(100),
+                np.zeros(50),
+                np.zeros(100),
+                np.array([0], dtype=np.int64),
+            )
+
+    def test_mismatched_nd_index_dims(self):
+        """ND index shape doesn't match direct input shape."""
+        x = pt.vector("x", shape=(None,))
+        mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(None, None))
+        y = pt.matrix("y", shape=(None, None))
+        out = x[mat_idx] + y
+        fn = pytensor.function([x, mat_idx, y], out, mode=NUMBA_MODE, trust_input=True)
+        assert_fused(fn)
+        # Matching: mat_idx=(3,4), y=(3,4) — should work
+        fn(np.zeros(100), np.zeros((3, 4), dtype=np.int64), np.zeros((3, 4)))
+        # Mismatched: mat_idx=(3,4), y=(3,5) — should error
+        with pytest.raises(Exception):
+            fn(np.zeros(100), np.zeros((3, 4), dtype=np.int64), np.zeros((3, 5)))
+
+    # ============================================================
+    # Write-only index shape validation
+    # ============================================================
+
+    def test_write_only_index_no_broadcast(self):
+        """A write-only index must not broadcast against the loop.
+
+        target[idx] += exp(y) where idx is write-only (not used for reads).
+        If idx has shape (1,) at runtime but y has shape (50,), the loop
+        runs 50 iterations. The index must not silently broadcast — it
+        should error because the shapes don't match. 
+ """ + rng = np.random.default_rng(42) + y = pt.vector("y", shape=(None,)) + t = pt.vector("t", shape=(None,)) + idx = pt.vector("idx", dtype="int64", shape=(None,)) + out = pt.inc_subtensor(t[idx], pt.exp(y)) + fn = pytensor.function([y, t, idx], out, mode=NUMBA_MODE, trust_input=True) + fn_u = pytensor.function( + [y, t, idx], out, mode=NUMBA_NO_FUSION, trust_input=True + ) + assert_fused(fn) + # Matching shapes — should work + yv = rng.normal(size=50) + tv = rng.normal(size=100) + iv = rng.integers(100, size=50).astype(np.int64) + np.testing.assert_allclose(fn(yv, tv, iv), fn_u(yv, tv, iv), rtol=1e-10) + # Mismatched: idx=1, y=50 — should error + with pytest.raises(Exception): + fn(np.zeros(50), np.zeros(100), np.array([0], dtype=np.int64)) + + +class TestRepeatedAccumulationIndices: + """Test inc_subtensor with repeated indices (same position accumulated multiple times).""" + + def test_repeated_inc_subtensor(self): + """inc_subtensor with repeated indices should accumulate correctly.""" + rng = np.random.default_rng(42) + target = pt.vector("target", shape=(10,)) + x = pt.vector("x", shape=(20,)) + # Indices with repeats: multiple values go to the same target position + idx = np.array( + [0, 1, 2, 0, 1, 2, 3, 3, 3, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3], + dtype=np.int64, + ) + out = pt.inc_subtensor(target[idx], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, x], out) + assert_fused(fn) + tv = rng.normal(size=10) + xv = rng.normal(size=20) + np.testing.assert_allclose(fn(tv, xv), fn_u(tv, xv), rtol=1e-10) + + def test_repeated_set_subtensor(self): + """set_subtensor with repeated indices -- last write wins.""" + rng = np.random.default_rng(42) + target = pt.vector("target", shape=(10,)) + x = pt.vector("x", shape=(5,)) + idx = np.array([0, 0, 1, 1, 2], dtype=np.int64) + out = pt.set_subtensor(target[idx], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, x], out) + assert_fused(fn) + tv = rng.normal(size=10) + xv = rng.normal(size=5) + 
np.testing.assert_allclose(fn(tv, xv), fn_u(tv, xv), rtol=1e-10) + + def test_repeated_inc_with_read(self): + """Combined read + inc_subtensor with repeated indices.""" + rng = np.random.default_rng(42) + source = pt.vector("source", shape=(10,)) + target = pt.vector("target", shape=(10,)) + idx = np.array([0, 1, 0, 1, 2, 2, 3, 3], dtype=np.int64) + out = pt.inc_subtensor(target[idx], source[idx] * 2.0) + fn, fn_u = fused_and_unfused([source, target], out) + assert_fused(fn) + sv = rng.normal(size=10) + tv = rng.normal(size=10) + np.testing.assert_allclose(fn(sv, tv), fn_u(sv, tv), rtol=1e-10) + + +class TestRadonModel: + """Test fusion on the radon hierarchical model (logp + gradient).""" + + def test_radon_model_correctness(self): + import sys + + sys.path.insert(0, "tests/benchmarks") + from test_compilation import create_radon_model + + joined_inputs, [model_logp, model_dlogp] = create_radon_model() + fn, fn_u = fused_and_unfused([joined_inputs], [model_logp, model_dlogp]) + rng = np.random.default_rng(1) + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) + results_fused = fn(x) + results_unfused = fn_u(x) + for i, (rf, ru) in enumerate(zip(results_fused, results_unfused)): + np.testing.assert_allclose(rf, ru, rtol=1e-6, err_msg=f"Output {i}") + + def test_radon_model_no_unfused_indexing(self): + """After fusion, no AdvancedSubtensor1 or AdvancedIncSubtensor1 should remain.""" + import sys + + sys.path.insert(0, "tests/benchmarks") + from test_compilation import create_radon_model + + from pytensor.tensor.subtensor import AdvancedSubtensor1 + + joined_inputs, [model_logp, model_dlogp] = create_radon_model() + fn = pytensor.function( + [joined_inputs], + [model_logp, model_dlogp], + mode=NUMBA_MODE, + trust_input=True, + ) + nodes = fn.maker.fgraph.toposort() + assert not any(isinstance(n.op, AdvancedSubtensor1) for n in nodes) + assert not any(isinstance(n.op, AdvancedIncSubtensor1) for n in nodes) From 
3a81c2fae0b6d0ff679fb8464914ae9273b2bf64 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 1 Apr 2026 15:05:18 +0200 Subject: [PATCH 4/7] Numba: generalize IndexedElemwise to arbitrary-axis and multi-index Support AdvancedSubtensor on any axis (not just axis 0) and multi-index patterns like x[idx_row, idx_col] where multiple 1D index arrays address consecutive source axes. Generalize writes (AdvancedIncSubtensor) to match. Reads: - Add undo_take_dimshuffle_for_fusion pre-fusion rewrite - _get_indexed_read_info handles AdvancedSubtensor with consecutive tensor indices, full-slice prefix/suffix - Reject boolean indices and non-consecutive advanced indices Writes: - _get_indexed_update_info mirrors _get_indexed_read_info for AdvancedIncSubtensor - find_indexed_update_consumers detects both AdvancedIncSubtensor1 and AdvancedIncSubtensor - Broadcast guard generalized for non-axis-0 indexed axes - Scatter construction supports AdvancedIncSubtensor (inplace) Dispatch + codegen: - indexed_inputs encoding: ((positions, axis, idx_bc), ...) 
- input_read_spec uses tuple of (idx_k, axis) pairs per input - n_index_loop_dims = max(idx.ndim for group) --- pytensor/link/numba/dispatch/elemwise.py | 13 +- .../link/numba/dispatch/vectorize_codegen.py | 258 ++++++++++++++---- pytensor/tensor/rewriting/indexed_elemwise.py | 241 +++++++++++++--- tests/link/numba/test_indexed_elemwise.py | 195 +------------ 4 files changed, 424 insertions(+), 283 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 17f26b4860..f808f99f87 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -435,16 +435,21 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): nin_elemwise = len(elemwise_node.inputs) nout = len(elemwise_node.outputs) - core_op_fn = store_core_outputs( scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs ) # --- Broadcast and type encodings ------------------------------------- + # Use each input's result broadcastable (not the source's). + # For indexed reads, the result has the index dim(s) substituted. + # For multi-index, the result has fewer dims than the source. input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs) - output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs) + # Use Elemwise output bc (loop dims) for ALL outputs — including updates. + # The fix-up block in _vectorized corrects the actual output/return types + # for update outputs to match the target buffer. 
+ output_bc_patterns = tuple(out.type.broadcastable for out in elemwise_node.outputs) output_dtypes = tuple(out.type.dtype for out in node.outputs) - # Filter out inplace entries for update outputs (handled by scatter) + # Filter out elemwise inplace for indexed-update outputs (they use update targets) inplace_pattern = tuple( (out_idx, inp_idx) for out_idx, inp_idx in elemwise_node.op.inplace_pattern.items() @@ -484,7 +489,7 @@ def impl(*outer_inputs): return impl - cache_version = (0, 2) + cache_version = (0, 3) if scalar_cache_key is None: key = None else: diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index a4d85d93b4..81b858cec3 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -373,12 +373,12 @@ def make_loop_call( input_types: tuple[Any, ...], output_types: tuple[Any, ...], core_scalar: bool = True, - input_read_spec: tuple[tuple[int, int] | None, ...] | None = None, + input_read_spec: tuple[tuple[tuple[int, int], ...] | None, ...] | None = None, idx_arrs: list | None = None, idx_types: tuple | None = None, - idx_load_axis: tuple[int, ...] | None = None, + idx_load_axes: tuple[tuple[int, ...], ...] | None = None, idx_bc: tuple[tuple[bool, ...], ...] | None = None, - output_update_spec: tuple[tuple[int, int] | None, ...] | None = None, + output_update_spec: tuple[tuple[tuple[int, int], ...] | None, ...] | None = None, ): safe = (False, False) @@ -431,15 +431,17 @@ def _wrap_negative_index(idx_val, dim_size, signed): idxs = [loopval.index for loopval in loops] # Load indirect indices for all index arrays. - # Each index array is 1D and is accessed by the loop counter for its axis. + # Each index array may be ND (e.g. a 2D matrix index), accessed by + # multiple loop counters corresponding to its load axes. 
indirect_idxs = [] - if idx_arrs is not None and idx_load_axis is not None: - for gi_k, (gi_arr, gi_type, ax) in enumerate( - zip(idx_arrs, idx_types, idx_load_axis) + if idx_arrs is not None and idx_types is not None and idx_load_axes is not None: + for gi_k, (gi_arr, gi_type, load_axes) in enumerate( + zip(idx_arrs, idx_types, idx_load_axes) ): - # Use zero if the index is statically broadcastable, loop counter otherwise - idx_is_bc = idx_bc[gi_k][0] if idx_bc and idx_bc[gi_k] else False - load_idx = zero if idx_is_bc else idxs[ax] + load_idxs = [] + for d, ax in enumerate(load_axes): + is_bc = idx_bc[gi_k][d] if idx_bc and len(idx_bc[gi_k]) > d else False + load_idxs.append(zero if is_bc else idxs[ax]) gi_ptr = cgutils.get_item_pointer2( context, builder, @@ -447,7 +449,7 @@ def _wrap_negative_index(idx_val, dim_size, signed): cgutils.unpack_tuple(builder, gi_arr.shape), cgutils.unpack_tuple(builder, gi_arr.strides), gi_type.layout, - [load_idx], + load_idxs, False, False, ) @@ -470,14 +472,27 @@ def _wrap_negative_index(idx_val, dim_size, signed): zip(inputs, input_types, input_bc, strict=True) ): spec = input_read_spec[input_i] if input_read_spec is not None else None - core_ndim = input_type.ndim - len(bc) + n_indexed = len(spec) if spec else 0 + # n_indexed source axes are replaced by the index arrays' broadcast + # loop dims. n_index_loop_dims = max ndim of the index arrays in the + # group (1 for 1D vectors, 2 for 2D matrices, etc.). 
+ n_index_loop_dims = ( + max((idx_types[idx_k].ndim for idx_k, _ in spec), default=0) if spec else 0 # type: ignore[index] + ) + core_ndim = input_type.ndim - len(bc) - n_indexed + n_index_loop_dims if spec is not None: assert idx_types is not None - # Single-index on one axis: replace that axis with the indirect index indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} input_shape = cgutils.unpack_tuple(builder, input.shape) idxs_bc = [] + # Result dims correspond to: [indexed_loop_dim(s), then non-indexed source axes]. + # For 1D multi-index on consecutive axes 0..n-1: + # result dim 0 = indexed loop dim + # result dim 1 = source dim n, etc. + # For ND single-index on axis A (e.g. 2D mat_idx): + # result dims A..A+D-1 = index loop dims + # remaining = non-indexed source axes result_dim = 0 for src_dim in range(input_type.ndim): if src_dim in indexed_axes: @@ -488,7 +503,10 @@ def _wrap_negative_index(idx_val, dim_size, signed): signed=idx_types[idx_k].dtype.signed, ) idxs_bc.append(idx_val) - result_dim += 1 + if n_indexed == 1: + result_dim += n_index_loop_dims + elif src_dim == max(indexed_axes): + result_dim += n_index_loop_dims else: if result_dim < len(bc): idxs_bc.append(zero if bc[result_dim] else idxs[result_dim]) @@ -556,10 +574,18 @@ def _wrap_negative_index(idx_val, dim_size, signed): spec = output_update_spec[output_i] if output_update_spec is not None else None if spec is not None: assert idx_types is not None - # Indexed-update output: same logic as indexed-read input + # Indexed-update output: same logic as indexed-read input. + # Recompute core_ndim from the target's actual dims since + # output_bc may not match (it's the Elemwise output bc). 
indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} n_indexed = len(indexed_axes) - source_batch_ndim = len(bc) + n_indexed - 1 + n_index_loop_dims = max( + (idx_types[idx_k].ndim for idx_k, _ in spec), default=0 + ) + # Number of source (target) batch dims + source_batch_ndim = len(bc) + n_indexed - n_index_loop_dims + core_ndim = output_type.ndim - source_batch_ndim + max_indexed_axis = max(indexed_axes) idxs_bc = [] loop_dim = 0 for src_dim in range(source_batch_ndim): @@ -571,6 +597,8 @@ def _wrap_negative_index(idx_val, dim_size, signed): signed=idx_types[idx_k].dtype.signed, ) idxs_bc.append(idx_val) + if src_dim >= max_indexed_axis: + loop_dim += n_index_loop_dims else: bc_dim = bc[loop_dim] if loop_dim < len(bc) else False idxs_bc.append(zero if bc_dim else idxs[loop_dim]) @@ -691,25 +719,31 @@ def _vectorized( raise TypingError("allow_core_scalar must be literal.") allow_core_scalar = allow_core_scalar.literal_value - # Count scatter targets (one per scattered output) - n_update_targets = sum( - len(entry[0]) for entry in indexed_outputs if entry is not None - ) + # Count scatter targets (one per unique output index) + _update_out_idxs = set() + for entry in indexed_outputs: + if entry is not None: + _update_out_idxs.update(entry[0]) + n_update_targets = len(_update_out_idxs) n_indices = len(indexed_inputs) n_elemwise = len(outer_input_types) - n_indices - n_update_targets - input_types = tuple(outer_input_types[i] for i in range(n_elemwise)) + source_input_types = tuple(outer_input_types[i] for i in range(n_elemwise)) idx_types = tuple(outer_input_types[n_elemwise + k] for k in range(n_indices)) update_target_types = tuple( outer_input_types[n_elemwise + n_indices + j] for j in range(n_update_targets) ) # indexed_inputs entries are (positions, axis) — one per index array. - # We aggregate per-input into a tuple of (idx_k, src_axis) pairs. + # For multi-index (e.g. x[idx_row, idx_col]), an input appears in multiple + # entries with different axes. 
We aggregate per-input into a tuple of + # (idx_k, src_axis) pairs. # - # idx_load_axis[k] = which loop dim loads index array k. + # idx_load_axes[k] = tuple of loop dims used to load index array k. + # For a 1-D index on axis A this is (A,); for a 2-D index it is (A, A+1), etc. + # For multi-index on consecutive axes, the group's min_axis is the start. _read_spec_dict: dict[int, list[tuple[int, int]]] = {} - idx_load_axis = [] + idx_load_axes = [] idx_bc_list = [] # per index array: broadcastable tuple for k, entry in enumerate(indexed_inputs): positions, axis = entry[0], entry[1] @@ -717,15 +751,57 @@ def _vectorized( idx_bc_list.append(idx_bc) for p in positions: _read_spec_dict.setdefault(p, []).append((k, axis)) - # idx_load_axis[k] = which loop dim to use when loading index array k. + # Build write-side grouping: index arrays that update the same output + # share a group and should use the same min_axis. + _write_group: dict[int, list[tuple[int, int]]] = {} # out_idx -> [(k, axis)] + for k, entry in enumerate(indexed_outputs): + if entry is None: + continue + output_indices, _mode, axis = entry + for out_idx in output_indices: + _write_group.setdefault(out_idx, []).append((k, axis)) + + # idx_load_axes[k] = tuple of loop dims for loading index array k. 
for k, entry in enumerate(indexed_inputs): _positions, axis = entry[0], entry[1] - idx_load_axis.append(axis) + idx_ndim = idx_types[k].ndim + # Find the group's min_axis from reads and writes + min_axis = axis + for p in _positions: + if p in _read_spec_dict: + for _other_k, other_axis in _read_spec_dict[p]: + min_axis = min(min_axis, other_axis) + for out_idx, group in _write_group.items(): + if any(gk == k for gk, _ in group): + for _other_k, other_axis in group: + min_axis = min(min_axis, other_axis) + idx_load_axes.append(tuple(range(min_axis, min_axis + idx_ndim))) input_read_spec = tuple( tuple(_read_spec_dict[p]) if p in _read_spec_dict else None for p in range(n_elemwise) ) - idx_load_axis = tuple(idx_load_axis) + idx_load_axes = tuple(idx_load_axes) + + # Build effective input types that match input_bc_patterns ndim. + # For indexed inputs, the source ndim differs from the result ndim: + # - Multi-index (N 1-D indices on N axes): collapses N source axes into 1 loop dim + # - ND index (1 index with ndim D on 1 axis): expands 1 source axis into D loop dims + input_types = [] + for p, src_type in enumerate(source_input_types): + spec = input_read_spec[p] + if spec is not None: + n_indexed_axes = len(spec) + n_index_loop_dims = max(idx_types[idx_k].ndim for idx_k, _ in spec) + if n_indexed_axes != n_index_loop_dims: + effective_ndim = src_type.ndim - n_indexed_axes + n_index_loop_dims + input_types.append( + types.Array(src_type.dtype, effective_ndim, src_type.layout) + ) + else: + input_types.append(src_type) + else: + input_types.append(src_type) + input_types = tuple(input_types) # Per-output: tuple of (idx_k, axis) pairs, or None. # Same format as input_read_spec. 
@@ -740,9 +816,10 @@ def _vectorized( output_indices, _mode, axis = entry for out_idx in output_indices: _update_spec_dict.setdefault(out_idx, []).append((k, axis)) - update_out_to_target[out_idx] = target_counter + if out_idx not in update_out_to_target: + update_out_to_target[out_idx] = target_counter + target_counter += 1 update_out_indices.add(out_idx) - target_counter += 1 output_update_spec = tuple( tuple(_update_spec_dict[p]) if p in _update_spec_dict else None for p in range(len(output_bc_patterns)) @@ -767,9 +844,20 @@ def _vectorized( for out_idx, target_idx in update_out_to_target.items(): target_type = update_target_types[target_idx] out_types[out_idx] = target_type + # Core ndim = target dims minus the dims addressed by the loop. + # For multi-index or ND, the indexed axes are replaced by loop dims + # via indirect indexing, so the "batch" portion of the target is + # the number of source dims addressed by indirect + loop counters. + spec = output_update_spec[out_idx] + if spec is not None: + n_indexed = len(spec) + n_index_loop_dims = max(idx_types[idx_k].ndim for idx_k, _ in spec) + effective_batch = batch_ndim + n_indexed - n_index_loop_dims + else: + effective_batch = batch_ndim core_out_types[out_idx] = types.Array( dtype=target_type.dtype, - ndim=target_type.ndim - batch_ndim, + ndim=target_type.ndim - effective_batch, layout=target_type.layout, ) out_types = tuple(out_types) @@ -791,12 +879,13 @@ def _vectorized( inplace_pattern_val = inplace_pattern input_read_spec_val = input_read_spec idx_types_val = idx_types - idx_load_axis_val = idx_load_axis + idx_load_axes_val = idx_load_axes idx_bc_list_val = idx_bc_list output_update_spec_val = output_update_spec update_out_to_target_val = update_out_to_target update_target_types_val = update_target_types update_out_indices_val = update_out_indices + indexed_outputs_val = indexed_outputs def codegen(ctx, builder, sig, args): [ @@ -824,7 +913,7 @@ def codegen(ctx, builder, sig, args): # First n_elemwise 
outer inputs are elemwise inputs (source arrays) inputs = [ - arrayobj.make_array(input_types[i])(ctx, builder, all_outer[i]) + arrayobj.make_array(source_input_types[i])(ctx, builder, all_outer[i]) for i in range(n_elemwise) ] in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] @@ -846,8 +935,11 @@ def codegen(ctx, builder, sig, args): ] # Build iter_shapes for compute_itershape. - # For indexed inputs, substitute the index array length on the - # indexed axis of the source shape. + # For indexed inputs, the source array may have more dims than the + # iteration shape (multi-index collapses multiple source axes into one + # loop dim). Replace the source shape with a constructed shape that + # matches the bc pattern: one entry per loop dim, with index lengths + # substituted for the indexed loop dim(s). iter_shapes = list(in_shapes) iter_bc = list(input_bc_patterns_val) idx_shapes = [ @@ -856,35 +948,83 @@ def codegen(ctx, builder, sig, args): for p, spec in enumerate(input_read_spec_val): if spec is None: continue - # Single-index: substitute on the indexed axis - idx_k, axis = spec[0] - iter_shapes[p] = list(iter_shapes[p]) - iter_shapes[p][axis] = idx_shapes[idx_k][0] + indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} + n_indexed = len(indexed_axes) + n_index_loop_dims = max(idx_types_val[idx_k].ndim for idx_k, _ in spec) + if n_indexed == n_index_loop_dims: + # Simple case (1 index on 1 axis, or N 1-D indices on N axes): + # substitute each indexed axis with the index's shape dim. + idx_k, axis = spec[0] + iter_shapes[p] = list(iter_shapes[p]) + iter_shapes[p][axis] = idx_shapes[idx_k][0] + else: + # Mismatch: ND index on fewer axes or multi-index collapsing. + # Build shape mapping result loop dims to source dims or index dims. + # Indexed source axes expand to n_index_loop_dims result dims. 
+ source_shape = cgutils.unpack_tuple(builder, inputs[p].shape) + batch_ndim = len(input_bc_patterns_val[p]) + indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} + max_axis = max(a for _, a in spec) + new_shape = [] + src_d = 0 + idx_d = 0 + for loop_d in range(batch_ndim): + if src_d in indexed_axes and idx_d < n_index_loop_dims: + idx_k_first = spec[0][0] + new_shape.append(idx_shapes[idx_k_first][idx_d]) + idx_d += 1 + if idx_d >= n_index_loop_dims: + src_d = max_axis + 1 + else: + new_shape.append(source_shape[src_d]) + src_d += 1 + iter_shapes[p] = new_shape - # Each index array participates in iter_shape validation on its load - # axis, just like any direct input. - # - # Read indices use their static broadcastable so shape-1 indices - # broadcast correctly against other inputs. + # Each index array participates in iter_shape validation. # - # Write indices are forced to bc=False: unlike reads, numpy does - # not allow the index to broadcast against the update value. A - # shape-1 write index at runtime must match a size-1 loop, not - # silently repeat writes to the same target position. + # Write indices can broadcast against each other (e.g. ir=(3,1) + # and ic=(1,4) → (3,4)), so we honour their static bc. But if + # ALL write indices on a given loop dim are bc=True, none of them + # constrains the loop size and a shape-1 index would silently + # repeat writes. In that case we force bc=False so + # compute_itershape requires the index length to match the loop. batch_ndim = len(input_bc_patterns_val[0]) if input_bc_patterns_val else 0 one = ir.IntType(64)(1) + + # Per loop dim: is every write index broadcastable? 
+ _write_all_bc = [True] * batch_ndim for k in range(n_indices): - ax = idx_load_axis_val[k] - is_write = indexed_outputs[k] is not None + if indexed_outputs_val[k] is None: + continue + load_axes = idx_load_axes_val[k] + for d, ax in enumerate(load_axes): + if ax < batch_ndim: + idx_bc_on_d = ( + idx_bc_list_val[k][d] if d < len(idx_bc_list_val[k]) else False + ) + if not idx_bc_on_d: + _write_all_bc[ax] = False + + for k in range(n_indices): + load_axes = idx_load_axes_val[k] + is_write = indexed_outputs_val[k] is not None idx_shape_entry = [one] * batch_ndim - idx_shape_entry[ax] = idx_shapes[k][0] + bc_entry = [True] * batch_ndim + for d, ax in enumerate(load_axes): + if ax < batch_ndim and d < len(idx_shapes[k]): + idx_shape_entry[ax] = idx_shapes[k][d] + idx_bc_on_d = ( + idx_bc_list_val[k][d] if d < len(idx_bc_list_val[k]) else False + ) + # Force non-bc if this is a write index and all write + # indices on this dim are bc — otherwise the loop dim + # is unconstrained by any write index. 
+ if is_write and idx_bc_on_d and _write_all_bc[ax]: + idx_bc_on_d = False + if ax < batch_ndim: + bc_entry[ax] = idx_bc_on_d iter_shapes.append(idx_shape_entry) - idx_bc_on_ax = idx_bc_list_val[k][0] if idx_bc_list_val[k] else False - if is_write: - idx_bc_on_ax = False - iter_bc.append( - tuple(True if d != ax else idx_bc_on_ax for d in range(batch_ndim)) - ) + iter_bc.append(tuple(bc_entry)) iter_shape = compute_itershape( ctx, @@ -915,7 +1055,7 @@ def codegen(ctx, builder, sig, args): output_dtypes_val, inplace_pattern_val, inputs, - input_types, + source_input_types, output_core_shapes, update_outputs=update_outputs_dict, ) @@ -942,13 +1082,13 @@ def codegen(ctx, builder, sig, args): outputs, input_bc_patterns_val, output_bc_patterns_val, - input_types, + source_input_types, output_types, core_scalar=allow_core_scalar, input_read_spec=input_read_spec_val, idx_arrs=idx_arrs, idx_types=idx_types_val, - idx_load_axis=idx_load_axis_val, + idx_load_axes=idx_load_axes_val, idx_bc=idx_bc_list_val, output_update_spec=output_update_spec_val, ) diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index d180530a16..8e23fb5bf4 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -8,18 +8,91 @@ from pytensor.compile import optdb from pytensor.compile.builders import OpFromGraph -from pytensor.graph.rewriting.basic import GraphRewriter +from pytensor.graph import node_rewriter +from pytensor.graph.rewriting.basic import GraphRewriter, dfs_rewriter from pytensor.graph.rewriting.db import SequenceDB from pytensor.printing import op_debug_information from pytensor.scalar.basic import Composite, identity -from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, AdvancedIncSubtensor1, 
+ AdvancedSubtensor, AdvancedSubtensor1, + indices_from_subtensor, ) +@node_rewriter([DimShuffle]) +def undo_take_dimshuffle_for_fusion(fgraph, node): + """Undo ``DimShuffle(AdvancedSubtensor1(DimShuffle(x), idx))`` -> ``AdvancedSubtensor(x, :, ..., idx, :, ...)``. + + The ``local_replace_AdvancedSubtensor`` specialize rewrite converts + ``x[:, idx]`` into ``x.T[idx].T`` (axis-swap + AdvancedSubtensor1 + + axis-swap). This rewrite undoes that when the result feeds a single + Elemwise, so ``FuseIndexedElemwise`` can absorb the indexing directly + on the correct axis. + + See also ``undo_take_reshape_for_fusion`` which handles the analogous + Reshape+flatten pattern for ND indices. + """ + # Outer DimShuffle must be an axis swap + outer_ds = node.op + if outer_ds.augment or outer_ds.drop: + return None + order = outer_ds.new_order + ndim = len(order) + if ndim < 2: + return None + + # Find the swapped axis: exactly two positions differ from identity + swapped = [i for i in range(ndim) if order[i] != i] + if len(swapped) != 2: + return None + ax_a, ax_b = swapped + if order[ax_a] != ax_b or order[ax_b] != ax_a: + return None + axis = max(ax_a, ax_b) # the non-zero axis (0 was swapped to axis) + + # Inner must be AdvancedSubtensor1 + inner = node.inputs[0] + if inner.owner is None or not isinstance(inner.owner.op, AdvancedSubtensor1): + return None + asub1_node = inner.owner + + # AdvancedSubtensor1's input must be an inverse DimShuffle (same swap) + inner_ds_var = asub1_node.inputs[0] + if inner_ds_var.owner is None or not isinstance(inner_ds_var.owner.op, DimShuffle): + return None + inner_ds = inner_ds_var.owner.op + if inner_ds.new_order != tuple(order): + return None + + # Both intermediates must be single-client + if len(fgraph.clients[inner]) != 1: + return None + if len(fgraph.clients[inner_ds_var]) != 1: + return None + + # Outer DimShuffle must be consumed only by a single Elemwise + clients = fgraph.clients[node.outputs[0]] + if len(clients) != 1: + return 
None + client_node, _client_idx = clients[0] + if not isinstance(getattr(client_node, "op", None), Elemwise): + return None + + # Build AdvancedSubtensor: x[:, ..., idx, :, ...] + source = inner_ds_var.owner.inputs[0] + idx_var = asub1_node.inputs[1] + + idx_list = [slice(None)] * ndim + idx_list[axis] = 0 # pointer to the single index variable + new_out = AdvancedSubtensor(idx_list=idx_list)(source, idx_var) + return [new_out] + + indexed_elemwise_optdb = SequenceDB() optdb.register( "fuse_indexed_into_elemwise", @@ -30,6 +103,13 @@ position=100, ) +indexed_elemwise_optdb.register( + "undo_take_dimshuffle_for_fusion", + dfs_rewriter(undo_take_dimshuffle_for_fusion), + "numba", + position=0, +) + class IndexedElemwise(OpFromGraph): """Fuse indexed reads and updates into a single Elemwise iteration loop. @@ -135,19 +215,44 @@ def _get_indexed_read_info(var): Returns ``(source, [(idx_var, axis), ...])`` or ``None``. Handles: - ``AdvancedSubtensor1(source, idx)`` -> single index on axis 0 + - ``AdvancedSubtensor(source, :, ..., idx, :, ...)`` -> single or + multi-index with tensor indices on consecutive axes, + followed only by full slices. 
""" if var.owner is None: return None op = var.owner.op if isinstance(op, AdvancedSubtensor1): return (var.owner.inputs[0], [(var.owner.inputs[1], 0)]) + if isinstance(op, AdvancedSubtensor): + indices = indices_from_subtensor(var.owner.inputs[1:], op.idx_list) + # Collect consecutive advanced (tensor) indices + adv = [] + for i, idx in enumerate(indices): + if idx == slice(None): + if adv: + break # trailing slices after the advanced group + elif hasattr(idx, "ndim") and idx.ndim >= 1 and idx.dtype != "bool": + if adv and adv[-1][1] != i - 1: + return None # non-consecutive + adv.append((idx, i)) + else: + return None # unsupported index type + if not adv: + return None + # Verify only full slices remain after the advanced group + for idx in indices[adv[-1][1] + 1 :]: + if idx != slice(None): + return None + return (var.owner.inputs[0], adv) return None def find_indexed_input_groups(fgraph, node): """Find single-client indexed-read inputs grouped by (index, axis). Returns ``[(idx_var, axis, (pos, ...))]`` -- one entry per distinct - ``(idx_var, axis)`` pair. + ``(idx_var, axis)`` pair. A multi-index input contributes multiple + entries (one per indexed axis). """ groups = {} # (idx_var, axis) -> (idx_var, axis, list of positions) for i, inp in enumerate(node.inputs): @@ -165,39 +270,93 @@ def find_indexed_input_groups(fgraph, node): return [(var, axis, tuple(pos)) for var, axis, pos in groups.values()] + def _get_indexed_update_info(client_node): + """Extract indexed-update info from an AdvancedInc node. + + Returns ``(target, [(idx_var, axis), ...], mode)`` or ``None``. 
+ """ + op = client_node.op + if isinstance(op, AdvancedIncSubtensor1): + target, _val, idx_var = client_node.inputs + mode = "set" if op.set_instead_of_inc else "inc" + return (target, [(idx_var, 0)], mode) + if isinstance(op, AdvancedIncSubtensor): + target = client_node.inputs[0] + _val = client_node.inputs[1] + index_vars = client_node.inputs[2:] + indices = indices_from_subtensor(index_vars, op.idx_list) + adv = [] + for i, idx in enumerate(indices): + if idx == slice(None): + if adv: + break + elif hasattr(idx, "ndim") and idx.ndim >= 1 and idx.dtype != "bool": + if adv and adv[-1][1] != i - 1: + return None + adv.append((idx, i)) + else: + return None + if not adv: + return None + for idx in indices[adv[-1][1] + 1 :]: + if idx != slice(None): + return None + mode = "set" if op.set_instead_of_inc else "inc" + return (target, adv, mode) + return None + def find_indexed_update_consumers(fgraph, node): - """Find AdvancedIncSubtensor1 consumers of Elemwise outputs. + """Find indexed-update consumers of Elemwise outputs. - Returns ``{out_idx: (update_node, target, idx_var, mode)}``. + Returns ``{out_idx: (update_node, target, [(idx_var, axis), ...], mode)}``. Only considers outputs that are the value input (position 1) of the indexed update. 
""" update_info = {} for out_idx, out in enumerate(node.outputs): clients = fgraph.clients[out] - # Find AdvancedIncSubtensor1 client at val position (1) inc_clients = [ (c, ci) for c, ci in clients - if ci == 1 and isinstance(c.op, AdvancedIncSubtensor1) + if ci == 1 + and isinstance(c.op, AdvancedIncSubtensor1 | AdvancedIncSubtensor) ] if len(inc_clients) != 1: continue - [(client_node, _client_inp_idx)] = inc_clients - target, val, idx_var = client_node.inputs + [(client_node, _)] = inc_clients + info = _get_indexed_update_info(client_node) + if info is None: + continue + target, idx_axis_pairs, mode = info # Don't fuse if the value broadcasts on the index loop dim # (constant across index — recomputing per position is wasteful) # or against non-indexed target axes. - if val.type.broadcastable[0]: + val = client_node.inputs[1] + n_idx_dims = max((idx.ndim for idx, _ in idx_axis_pairs), default=0) + val_idx_bc = list(val.type.broadcastable)[:n_idx_dims] + if any(val_idx_bc): continue - val_bc = val.type.broadcastable[1:] - target_bc = target.type.broadcastable[1:] - if len(val_bc) < len(target_bc) or any( - vbc and not tbc for vbc, tbc in zip(val_bc, target_bc, strict=False) + indexed_axes = {a for _, a in idx_axis_pairs} + non_indexed_target_bc = [ + bc + for i, bc in enumerate(target.type.broadcastable) + if i not in indexed_axes + ] + non_indexed_val_bc = list(val.type.broadcastable) + non_indexed_val_bc = non_indexed_val_bc[n_idx_dims:] + if len(non_indexed_val_bc) < len(non_indexed_target_bc) or any( + vbc and not tbc + for vbc, tbc in zip( + non_indexed_val_bc, non_indexed_target_bc, strict=False + ) ): continue - mode = "set" if client_node.op.set_instead_of_inc else "inc" - update_info[out_idx] = (client_node, target, idx_var, mode) + update_info[out_idx] = ( + client_node, + target, + idx_axis_pairs, + mode, + ) return update_info for node in reversed(fgraph.toposort()): @@ -244,10 +403,11 @@ def find_indexed_update_consumers(fgraph, node): key = 
(idx_var, _ax) if key not in all_idx_groups: all_idx_groups[key] = (idx_var, len(all_idx_groups)) - for _un, _target, idx_var, _mode in update_consumers.values(): - key = (idx_var, 0) # updates are axis 0 for now - if key not in all_idx_groups: - all_idx_groups[key] = (idx_var, len(all_idx_groups)) + for _un, _target, idx_axis_pairs, _mode in update_consumers.values(): + for idx_var, axis in idx_axis_pairs: + key = (idx_var, axis) + if key not in all_idx_groups: + all_idx_groups[key] = (idx_var, len(all_idx_groups)) n_indices = len(all_idx_groups) idx_vars = [None] * n_indices @@ -319,7 +479,7 @@ def find_indexed_update_consumers(fgraph, node): inner_outputs = list(node.outputs) call_inputs = list(inner_inputs) for out_idx in sorted(update_consumers.keys()): - update_node, target, idx_var, _mode = update_consumers[out_idx] + update_node, target, idx_axis_pairs, _mode = update_consumers[out_idx] inner_inputs.append(target) @@ -333,15 +493,31 @@ def find_indexed_update_consumers(fgraph, node): scatter_value = node.outputs[scatter_idx] if update_node.op.inplace: - scatter_out = update_node.op( + if isinstance(update_node.op, AdvancedIncSubtensor1): + scatter_out = update_node.op( + target, scatter_value, update_node.inputs[2] + ) + else: + scatter_out = update_node.op( + target, scatter_value, *update_node.inputs[2:] + ) + elif isinstance(update_node.op, AdvancedIncSubtensor1): + inplace_op = AdvancedIncSubtensor1( + inplace=True, + set_instead_of_inc=update_node.op.set_instead_of_inc, + ) + scatter_out = inplace_op( target, scatter_value, update_node.inputs[2] ) else: - inplace_op = AdvancedIncSubtensor1( + inplace_op = AdvancedIncSubtensor( + idx_list=update_node.op.idx_list, inplace=True, set_instead_of_inc=update_node.op.set_instead_of_inc, ) - scatter_out = inplace_op(target, scatter_value, idx_var) + scatter_out = inplace_op( + target, scatter_value, *update_node.inputs[2:] + ) inner_outputs[scatter_idx] = scatter_out outer_destroy_map[scatter_idx] = 
[target_pos] @@ -367,14 +543,15 @@ def find_indexed_update_consumers(fgraph, node): # Build indexed_outputs spec for the Op indexed_outputs_spec = [None] * n_indices for out_idx in sorted(update_consumers.keys()): - _update_node, _target, idx_var, mode = update_consumers[out_idx] - key = (idx_var, 0) - idx_pos = all_idx_groups[key][1] - scatter_idx = dup_map.get(out_idx, out_idx) - if indexed_outputs_spec[idx_pos] is None: - indexed_outputs_spec[idx_pos] = ([scatter_idx], mode, 0) - else: - indexed_outputs_spec[idx_pos][0].append(scatter_idx) + _update_node, _target, idx_axis_pairs, mode = update_consumers[out_idx] + for idx_var, axis in idx_axis_pairs: + key = (idx_var, axis) + idx_pos = all_idx_groups[key][1] + scatter_idx = dup_map.get(out_idx, out_idx) + if indexed_outputs_spec[idx_pos] is None: + indexed_outputs_spec[idx_pos] = ([scatter_idx], mode, axis) + else: + indexed_outputs_spec[idx_pos][0].append(scatter_idx) indexed_outputs_spec = tuple( (tuple(e[0]), e[1], e[2]) if e is not None else None for e in indexed_outputs_spec diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index 049062d118..de0988d2d3 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -8,11 +8,7 @@ from pytensor import config from pytensor.compile.mode import get_mode from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise -from pytensor.tensor.subtensor import ( - AdvancedIncSubtensor1, - AdvancedSubtensor, - advanced_subtensor1, -) +from pytensor.tensor.subtensor import AdvancedIncSubtensor1, advanced_subtensor1 numba = pytest.importorskip("numba") @@ -130,121 +126,6 @@ def test_broadcast_index_axis1(self): yv = np.ones((3, 50)) np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) - def test_nd_index_axis0(self): - """2D matrix index on axis 0.""" - rng = np.random.default_rng(42) - x = pt.vector("x", shape=(100,)) - mat_idx = pt.matrix("mat_idx", 
dtype="int64", shape=(10, 5)) - y = pt.matrix("y", shape=(10, 5)) - fn, fn_u = fused_and_unfused([x, mat_idx, y], pt.exp(x[mat_idx]) + y) - assert_fused(fn) - xv = rng.normal(size=(100,)) - iv = rng.integers(100, size=(10, 5)).astype(np.int64) - yv = rng.normal(size=(10, 5)) - np.testing.assert_allclose(fn(xv, iv, yv), fn_u(xv, iv, yv), rtol=1e-10) - - def test_nd_index_axis1(self): - """2D matrix index on axis 1 (via undo_take_reshape_for_fusion).""" - rng = np.random.default_rng(42) - x = pt.matrix("x", shape=(3, 100)) - mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) - fn, fn_u = fused_and_unfused([x, mat_idx], pt.exp(x[:, mat_idx])) - assert_fused(fn) - xv = rng.normal(size=(3, 100)) - iv = rng.integers(100, size=(10, 5)).astype(np.int64) - np.testing.assert_allclose(fn(xv, iv), fn_u(xv, iv), rtol=1e-10) - - def test_nd_index_with_trailing_dims(self): - """2D index on axis 0 with trailing source dims.""" - rng = np.random.default_rng(42) - x = pt.matrix("x", shape=(100, 7)) - mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) - fn, fn_u = fused_and_unfused([x, mat_idx], pt.exp(x[mat_idx])) - assert_fused(fn) - xv = rng.normal(size=(100, 7)) - iv = rng.integers(100, size=(10, 5)).astype(np.int64) - np.testing.assert_allclose(fn(xv, iv), fn_u(xv, iv), rtol=1e-10) - - def test_nd_index_broadcast(self): - """Broadcastable 2D index (shape (1, C)) broadcasts against direct input.""" - rng = np.random.default_rng(42) - x = pt.vector("x", shape=(100,)) - bc_idx = pt.matrix("bc_idx", dtype="int64", shape=(1, 5)) - y = pt.matrix("y", shape=(10, 5)) - fn, fn_u = fused_and_unfused([x, bc_idx, y], x[bc_idx] + y) - assert_fused(fn) - xv = rng.normal(size=(100,)) - bcv = rng.integers(100, size=(1, 5)).astype(np.int64) - yv = rng.normal(size=(10, 5)) - np.testing.assert_allclose(fn(xv, bcv, yv), fn_u(xv, bcv, yv), rtol=1e-10) - - def test_scalar_and_vector_index(self): - """x[scalar_idx, vector_idx] — 0-d index broadcasts with 1-d index.""" - rng = 
np.random.default_rng(42) - x = pt.matrix("x", shape=(100, 200)) - scalar_idx = pt.scalar("scalar_idx", dtype="int64") - vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) - y = pt.vector("y", shape=(50,)) - fn, fn_u = fused_and_unfused( - [x, scalar_idx, vec_idx, y], x[scalar_idx, vec_idx] + y - ) - assert_fused(fn) - xv = rng.normal(size=(100, 200)) - sv = np.array(rng.integers(100), dtype=np.int64) - vv = rng.integers(200, size=50).astype(np.int64) - yv = rng.normal(size=50) - np.testing.assert_allclose(fn(xv, sv, vv, yv), fn_u(xv, sv, vv, yv), rtol=1e-10) - - def test_vector_and_scalar_index(self): - """x[vector_idx, scalar_idx] — reversed order.""" - rng = np.random.default_rng(42) - x = pt.matrix("x", shape=(100, 200)) - vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) - scalar_idx = pt.scalar("scalar_idx", dtype="int64") - y = pt.vector("y", shape=(50,)) - fn, fn_u = fused_and_unfused( - [x, vec_idx, scalar_idx, y], x[vec_idx, scalar_idx] + y - ) - assert_fused(fn) - xv = rng.normal(size=(100, 200)) - vv = rng.integers(100, size=50).astype(np.int64) - sv = np.array(rng.integers(200), dtype=np.int64) - yv = rng.normal(size=50) - np.testing.assert_allclose(fn(xv, vv, sv, yv), fn_u(xv, vv, sv, yv), rtol=1e-10) - - def test_scalar_and_vector_index_with_trailing_dim(self): - """x[scalar_idx, vector_idx] on a 3-d tensor with trailing dim.""" - rng = np.random.default_rng(42) - x = pt.tensor3("x", shape=(100, 200, 7)) - scalar_idx = pt.scalar("scalar_idx", dtype="int64") - vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) - z = pt.matrix("z", shape=(50, 7)) - fn, fn_u = fused_and_unfused( - [x, scalar_idx, vec_idx, z], x[scalar_idx, vec_idx] + z - ) - assert_fused(fn) - xv = rng.normal(size=(100, 200, 7)) - sv = np.array(rng.integers(100), dtype=np.int64) - vv = rng.integers(200, size=50).astype(np.int64) - zv = rng.normal(size=(50, 7)) - np.testing.assert_allclose(fn(xv, sv, vv, zv), fn_u(xv, sv, vv, zv), rtol=1e-10) - - def 
test_all_scalar_indices(self): - """AdvancedSubtensor with all 0-d indices (degenerate case).""" - rng = np.random.default_rng(42) - x = pt.matrix("x", shape=(100, 200)) - si0 = pt.scalar("si0", dtype="int64") - si1 = pt.scalar("si1", dtype="int64") - y = pt.scalar("y", dtype="float64") - adv_sub = AdvancedSubtensor(idx_list=(0, 1))(x, si0, si1) - fn, fn_u = fused_and_unfused([x, si0, si1, y], adv_sub + y) - assert_fused(fn) - xv = rng.normal(size=(100, 200)) - s0 = np.array(rng.integers(100), dtype=np.int64) - s1 = np.array(rng.integers(200), dtype=np.int64) - yv = np.array(rng.normal()) - np.testing.assert_allclose(fn(xv, s0, s1, yv), fn_u(xv, s0, s1, yv), rtol=1e-10) - def test_negative_indices(self): """Negative indices must be handled correctly (sign-extended, not zero-extended).""" rng = np.random.default_rng(42) @@ -269,13 +150,10 @@ def test_negative_indices(self): class TestIndexedUpdateFusion: """Test indexed updates (AdvancedIncSubtensor1) fused into Elemwise.""" - def test_no_fusion_when_idx_axes_outside_elemwise_loop(self): - """Don't fuse if the indexed axes are not within the Elemwise loop. + def test_no_fusion_when_val_broadcasts_against_target(self): + """Don't fuse (yet) if elemwise output broadcasts against target's trailing axes. - Here the index is on axis 0 of target(5, 10), but the Elemwise - output (10,) corresponds to axis 1 (the non-indexed trailing axis). - The indexed axis doesn't overlap with the Elemwise computation, so - fusing would misalign which input dims map to which target dims. + TODO: support this by making the update output's core_ndim > 0. 
""" rng = np.random.default_rng(42) idx = rng.integers(5, size=10).astype(np.int64) @@ -285,9 +163,9 @@ def test_no_fusion_when_idx_axes_outside_elemwise_loop(self): elemwise_out = advanced_subtensor1(x, idx) + y # shape (10,) out = pt.inc_subtensor(target[idx], elemwise_out) fn, fn_u = fused_and_unfused([x, y, target], out) - # Write not fused — the Elemwise loop dim is the non-indexed axis, - # not the indexed axis. Read fusion may still create an - # IndexedElemwise, but the AdvancedIncSubtensor1 must remain outside. + # Write not fused — val (10,) would broadcast to (10, 10) in the target. + # Read fusion still creates an IndexedElemwise, but the + # AdvancedIncSubtensor1 must remain outside it. assert any( isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() ) @@ -317,35 +195,6 @@ def test_no_fusion_when_val_broadcasts_along_index_dim(self): tv = rng.normal(size=(5,)) np.testing.assert_allclose(fn(xv, tv), fn_u(xv, tv), rtol=1e-10) - def test_broadcast_val_into_non_indexed_dims(self): - """Fuse when Elemwise output broadcasts into target's non-indexed dims. - - target(5, 3)[:, idx] += exp(val) — the Elemwise loop covers the - index axis (axis 1), and the scalar result broadcasts into the - non-indexed leading axis (axis 0, core_ndim=1). - - Requires excluding local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1 - so we keep AdvancedIncSubtensor on the correct axis. 
- """ - rng = np.random.default_rng(42) - idx = rng.integers(3, size=10).astype(np.int64) - target = pt.matrix("target", shape=(5, 3)) - val = pt.vector("val", shape=(10,)) - out = pt.inc_subtensor(target[:, idx], pt.exp(val)) - - mode = NUMBA_MODE.excluding( - "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1" - ) - mode_u = NUMBA_NO_FUSION.excluding( - "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1" - ) - fn = pytensor.function([val, target], out, mode=mode, trust_input=True) - fn_u = pytensor.function([val, target], out, mode=mode_u, trust_input=True) - assert_fused(fn) - valv = rng.normal(size=(10,)) - tv = rng.normal(size=(5, 3)) - np.testing.assert_allclose(fn(valv, tv), fn_u(valv, tv), rtol=1e-10) - def test_inc_subtensor(self): rng = np.random.default_rng(42) idx = rng.integers(85, size=919).astype(np.int64) @@ -458,22 +307,6 @@ def test_combined_multi_index_read_write(self): fn(tv, sv, rv, cv, xv), fn_u(tv, sv, rv, cv, xv), rtol=1e-10 ) - def test_scalar_and_vector_index_inc(self): - """inc_subtensor with x[scalar_idx, vector_idx] — 0-d and 1-d indices.""" - rng = np.random.default_rng(42) - target = pt.matrix("target", shape=(100, 200)) - scalar_idx = pt.scalar("scalar_idx", dtype="int64") - vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) - x = pt.vector("x", shape=(50,)) - out = pt.inc_subtensor(target[scalar_idx, vec_idx], pt.exp(x)) - fn, fn_u = fused_and_unfused([target, scalar_idx, vec_idx, x], out) - assert_fused(fn) - tv = rng.normal(size=(100, 200)) - sv = np.array(rng.integers(100), dtype=np.int64) - vv = rng.integers(200, size=50).astype(np.int64) - xv = rng.normal(size=50) - np.testing.assert_allclose(fn(tv, sv, vv, xv), fn_u(tv, sv, vv, xv), rtol=1e-10) - class TestShapeValidation: """Test that mismatched index/input shapes raise runtime errors. 
@@ -578,20 +411,6 @@ def test_shared_read_write_index_no_broadcast(self): np.array([0], dtype=np.int64), ) - def test_mismatched_nd_index_dims(self): - """ND index shape doesn't match direct input shape.""" - x = pt.vector("x", shape=(None,)) - mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(None, None)) - y = pt.matrix("y", shape=(None, None)) - out = x[mat_idx] + y - fn = pytensor.function([x, mat_idx, y], out, mode=NUMBA_MODE, trust_input=True) - assert_fused(fn) - # Matching: mat_idx=(3,4), y=(3,4) — should work - fn(np.zeros(100), np.zeros((3, 4), dtype=np.int64), np.zeros((3, 4))) - # Mismatched: mat_idx=(3,4), y=(3,5) — should error - with pytest.raises(Exception): - fn(np.zeros(100), np.zeros((3, 4), dtype=np.int64), np.zeros((3, 5))) - # ============================================================ # Radon model integration test # ============================================================ From f339cdb775a200bd448f42b0f2ab97212e774402 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 1 Apr 2026 15:05:35 +0200 Subject: [PATCH 5/7] Numba: fuse ND and 0-d integer indices into IndexedElemwise Support multidimensional (e.g. 2D matrix) and 0-d integer indices in IndexedElemwise fusion, for both reads and writes. ND indices: - Add undo_take_reshape_for_fusion: undoes the Reshape+flatten pattern that transform_take applies for ND indices, recovering the original AdvancedSubtensor(source, mat_idx) form for fusion. Handles both axis=0 and axis>0 (with DimShuffle wrapping). - idx_load_axes: tuple of tuples, each index array loads from idx_ndim loop counters 0-d indices: - Accept 0-d tensor indices (e.g. x[scalar_idx, vec_idx]) which are valid AdvancedSubtensor inputs that broadcast with higher-dim indices. 
--- pytensor/link/numba/dispatch/elemwise.py | 2 +- .../link/numba/dispatch/vectorize_codegen.py | 11 +- pytensor/tensor/rewriting/indexed_elemwise.py | 130 ++++++++++++++- tests/link/numba/test_indexed_elemwise.py | 151 +++++++++++++++++- 4 files changed, 285 insertions(+), 9 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index f808f99f87..97d276c373 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -489,7 +489,7 @@ def impl(*outer_inputs): return impl - cache_version = (0, 3) + cache_version = (0, 4) if scalar_cache_key is None: key = None else: diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 81b858cec3..999fa27c72 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -940,6 +940,7 @@ def codegen(ctx, builder, sig, args): # loop dim). Replace the source shape with a constructed shape that # matches the bc pattern: one entry per loop dim, with index lengths # substituted for the indexed loop dim(s). + one = ir.IntType(64)(1) iter_shapes = list(in_shapes) iter_bc = list(input_bc_patterns_val) idx_shapes = [ @@ -966,19 +967,24 @@ def codegen(ctx, builder, sig, args): indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} max_axis = max(a for _, a in spec) new_shape = [] + new_bc = [] src_d = 0 idx_d = 0 for loop_d in range(batch_ndim): if src_d in indexed_axes and idx_d < n_index_loop_dims: - idx_k_first = spec[0][0] - new_shape.append(idx_shapes[idx_k_first][idx_d]) + # Placeholder — actual index shapes are contributed + # separately by each index array's iter_shape entry. 
+ new_shape.append(one) + new_bc.append(True) idx_d += 1 if idx_d >= n_index_loop_dims: src_d = max_axis + 1 else: new_shape.append(source_shape[src_d]) + new_bc.append(iter_bc[p][loop_d]) src_d += 1 iter_shapes[p] = new_shape + iter_bc[p] = tuple(new_bc) # Each index array participates in iter_shape validation. # @@ -989,7 +995,6 @@ def codegen(ctx, builder, sig, args): # repeat writes. In that case we force bc=False so # compute_itershape requires the index length to match the loop. batch_ndim = len(input_bc_patterns_val[0]) if input_bc_patterns_val else 0 - one = ir.IntType(64)(1) # Per loop dim: is every write index broadcastable? _write_all_bc = [True] * batch_ndim diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index 8e23fb5bf4..7670bea42a 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -15,6 +15,7 @@ from pytensor.scalar.basic import Composite, identity from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer +from pytensor.tensor.shape import Reshape from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -61,7 +62,9 @@ def undo_take_dimshuffle_for_fusion(fgraph, node): return None asub1_node = inner.owner - # AdvancedSubtensor1's input must be an inverse DimShuffle (same swap) + # AdvancedSubtensor1's input must be the inverse DimShuffle. + # For a pair swap, the inverse is the same permutation (swap is self-inverse). + # A general permutation would need argsort(order) here. 
inner_ds_var = asub1_node.inputs[0] if inner_ds_var.owner is None or not isinstance(inner_ds_var.owner.op, DimShuffle): return None @@ -93,6 +96,106 @@ def undo_take_dimshuffle_for_fusion(fgraph, node): return [new_out] +@node_rewriter([Reshape]) +def undo_take_reshape_for_fusion(fgraph, node): + """Undo ``Reshape(AdvancedSubtensor1(x, flatten(idx)), shape)`` for ND indices. + + ``transform_take`` rewrites ``x[mat_idx]`` (ND integer index) into + ``AdvancedSubtensor1(x, mat_idx.ravel()).reshape(mat_idx.shape + ...)``, + possibly with DimShuffle axis-swaps for non-zero axes. This rewrite + undoes that so ``FuseIndexedElemwise`` can absorb the ND index directly. + """ + [reshape_out] = node.outputs + + # Must feed a single Elemwise (or chain to one via another pre-fusion rewrite) + clients = fgraph.clients[reshape_out] + if len(clients) != 1: + return None + client_node, _ = clients[0] + if not isinstance(getattr(client_node, "op", None), Elemwise): + return None + + inner = node.inputs[0] + if inner.owner is None: + return None + + # --- Detect axis-0 pattern: Reshape(AdvancedSubtensor1(src, flatten(idx)), shape) + # --- Detect axis>0 pattern: Reshape(DimShuffle(AdvancedSubtensor1(DimShuffle(src), flatten(idx))), shape) + axis = 0 + asub1_node = None + + if isinstance(inner.owner.op, AdvancedSubtensor1): + asub1_node = inner.owner + elif isinstance(inner.owner.op, DimShuffle): + # Check for axis-swap DimShuffle wrapping AdvancedSubtensor1 + outer_ds = inner.owner.op + if outer_ds.augment or outer_ds.drop: + return None + order = outer_ds.new_order + ndim_ds = len(order) + if ndim_ds < 2: + return None + swapped = [i for i in range(ndim_ds) if order[i] != i] + if len(swapped) != 2: + return None + ax_a, ax_b = swapped + if order[ax_a] != ax_b or order[ax_b] != ax_a: + return None + axis = max(ax_a, ax_b) + + ds_inner = inner.owner.inputs[0] + if ds_inner.owner is None or not isinstance( + ds_inner.owner.op, AdvancedSubtensor1 + ): + return None + asub1_node = 
ds_inner.owner + + # AdvancedSubtensor1's source must be the inverse DimShuffle + src_var = asub1_node.inputs[0] + if src_var.owner is None or not isinstance(src_var.owner.op, DimShuffle): + return None + if src_var.owner.op.new_order != tuple(order): + return None + + # Intermediates must be single-client + if len(fgraph.clients[ds_inner]) != 1: + return None + if len(fgraph.clients[src_var]) != 1: + return None + else: + return None + + if asub1_node is None: + return None + + # The index input to AdvancedSubtensor1 must be Reshape{1}(mat_idx, [-1]) (flatten) + flat_idx = asub1_node.inputs[1] + if flat_idx.owner is None or not isinstance(flat_idx.owner.op, Reshape): + return None + if flat_idx.owner.op.ndim != 1: + return None + mat_idx = flat_idx.owner.inputs[0] + if mat_idx.ndim < 2: + return None + + # AdvancedSubtensor1 output must be single-client + if len(fgraph.clients[asub1_node.outputs[0]]) != 1: + return None + + # Recover the original source (unwrap inner DimShuffle if axis > 0) + if axis > 0: + source = asub1_node.inputs[0].owner.inputs[0] + else: + source = asub1_node.inputs[0] + + # Build AdvancedSubtensor: source[:, ..., mat_idx, :, ...] + src_ndim = source.type.ndim + idx_list = [slice(None)] * src_ndim + idx_list[axis] = 0 # pointer to the single index variable + new_out = AdvancedSubtensor(idx_list=idx_list)(source, mat_idx) + return [new_out] + + indexed_elemwise_optdb = SequenceDB() optdb.register( "fuse_indexed_into_elemwise", @@ -110,6 +213,13 @@ def undo_take_dimshuffle_for_fusion(fgraph, node): position=0, ) +indexed_elemwise_optdb.register( + "undo_take_reshape_for_fusion", + dfs_rewriter(undo_take_reshape_for_fusion), + "numba", + position=0.5, +) + class IndexedElemwise(OpFromGraph): """Fuse indexed reads and updates into a single Elemwise iteration loop. 
@@ -232,7 +342,7 @@ def _get_indexed_read_info(var): if idx == slice(None): if adv: break # trailing slices after the advanced group - elif hasattr(idx, "ndim") and idx.ndim >= 1 and idx.dtype != "bool": + elif hasattr(idx, "ndim") and idx.dtype != "bool": if adv and adv[-1][1] != i - 1: return None # non-consecutive adv.append((idx, i)) @@ -290,7 +400,7 @@ def _get_indexed_update_info(client_node): if idx == slice(None): if adv: break - elif hasattr(idx, "ndim") and idx.ndim >= 1 and idx.dtype != "bool": + elif hasattr(idx, "ndim") and idx.dtype != "bool": if adv and adv[-1][1] != i - 1: return None adv.append((idx, i)) @@ -441,6 +551,10 @@ def find_indexed_update_consumers(fgraph, node): multi_client_outs.add(out_idx) if multi_client_outs: + # Rebuild the Elemwise with duplicated outputs. + # e.g. Exp(x) -> Composite([x], [exp(x), exp(x)])(x) + # so the Elemwise produces [out0, out1] where out1 is + # available for the scatter to replace. scalar_op = node.op.scalar_op if isinstance(scalar_op, Composite): s_inputs = list(scalar_op.inputs) @@ -452,6 +566,7 @@ def find_indexed_update_consumers(fgraph, node): s_inputs = list(scalar_node.inputs) s_outputs = list(scalar_node.outputs) + # Map from original out_idx to the new duplicate out_idx. # Wrap duplicates with identity so Composite._cleanup_graph # doesn't clone the entire subgraph for repeated outputs. 
# TODO: _cleanup_graph should use identity instead of clone @@ -464,9 +579,11 @@ def find_indexed_update_consumers(fgraph, node): new_scalar_op = Composite(s_inputs, s_outputs) new_elemwise = Elemwise(new_scalar_op)(*node.inputs, return_list=True) + # Update node reference and remap outputs old_node = node node = new_elemwise[0].owner + # Rebuild inner_inputs with the new node's inputs inner_inputs = [ inp.owner.inputs[0] if i in indexed_positions else inp for i, inp in enumerate(node.inputs) @@ -489,10 +606,14 @@ def find_indexed_update_consumers(fgraph, node): else: call_inputs.append(target.copy()) + # Use the duplicate output for scatter if multi-client scatter_idx = dup_map.get(out_idx, out_idx) + # The value to scatter: use the (possibly duplicated) Elemwise output scatter_value = node.outputs[scatter_idx] + # Build the scatter output using the correct value if update_node.op.inplace: + # Rebuild with the new value if isinstance(update_node.op, AdvancedIncSubtensor1): scatter_out = update_node.op( target, scatter_value, update_node.inputs[2] @@ -518,7 +639,6 @@ def find_indexed_update_consumers(fgraph, node): scatter_out = inplace_op( target, scatter_value, *update_node.inputs[2:] ) - inner_outputs[scatter_idx] = scatter_out outer_destroy_map[scatter_idx] = [target_pos] @@ -565,6 +685,7 @@ def find_indexed_update_consumers(fgraph, node): indexed_outputs=indexed_outputs_spec, )(*call_inputs, return_list=True) + # The node whose outputs we need to replace in the outer graph orig_node = old_node if multi_client_outs else node replacements = [] for out_idx in range(len(orig_node.outputs)): @@ -573,6 +694,7 @@ def find_indexed_update_consumers(fgraph, node): scatter_idx = dup_map.get(out_idx, out_idx) replacements.append((update_node.outputs[0], new_outs[scatter_idx])) if out_idx in dup_map: + # Multi-client: also replace the raw elemwise output replacements.append( (orig_node.outputs[out_idx], new_outs[out_idx]) ) diff --git 
a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index de0988d2d3..c0773ee0fc 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -8,7 +8,11 @@ from pytensor import config from pytensor.compile.mode import get_mode from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise -from pytensor.tensor.subtensor import AdvancedIncSubtensor1, advanced_subtensor1 +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor1, + AdvancedSubtensor, + advanced_subtensor1, +) numba = pytest.importorskip("numba") @@ -126,6 +130,121 @@ def test_broadcast_index_axis1(self): yv = np.ones((3, 50)) np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + def test_nd_index_axis0(self): + """2D matrix index on axis 0.""" + rng = np.random.default_rng(42) + x = pt.vector("x", shape=(100,)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) + y = pt.matrix("y", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, mat_idx, y], pt.exp(x[mat_idx]) + y) + assert_fused(fn) + xv = rng.normal(size=(100,)) + iv = rng.integers(100, size=(10, 5)).astype(np.int64) + yv = rng.normal(size=(10, 5)) + np.testing.assert_allclose(fn(xv, iv, yv), fn_u(xv, iv, yv), rtol=1e-10) + + def test_nd_index_axis1(self): + """2D matrix index on axis 1 (via undo_take_reshape_for_fusion).""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(3, 100)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, mat_idx], pt.exp(x[:, mat_idx])) + assert_fused(fn) + xv = rng.normal(size=(3, 100)) + iv = rng.integers(100, size=(10, 5)).astype(np.int64) + np.testing.assert_allclose(fn(xv, iv), fn_u(xv, iv), rtol=1e-10) + + def test_nd_index_with_trailing_dims(self): + """2D index on axis 0 with trailing source dims.""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 7)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) + 
fn, fn_u = fused_and_unfused([x, mat_idx], pt.exp(x[mat_idx])) + assert_fused(fn) + xv = rng.normal(size=(100, 7)) + iv = rng.integers(100, size=(10, 5)).astype(np.int64) + np.testing.assert_allclose(fn(xv, iv), fn_u(xv, iv), rtol=1e-10) + + def test_nd_index_broadcast(self): + """Broadcastable 2D index (shape (1, C)) broadcasts against direct input.""" + rng = np.random.default_rng(42) + x = pt.vector("x", shape=(100,)) + bc_idx = pt.matrix("bc_idx", dtype="int64", shape=(1, 5)) + y = pt.matrix("y", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, bc_idx, y], x[bc_idx] + y) + assert_fused(fn) + xv = rng.normal(size=(100,)) + bcv = rng.integers(100, size=(1, 5)).astype(np.int64) + yv = rng.normal(size=(10, 5)) + np.testing.assert_allclose(fn(xv, bcv, yv), fn_u(xv, bcv, yv), rtol=1e-10) + + def test_scalar_and_vector_index(self): + """x[scalar_idx, vector_idx] — 0-d index broadcasts with 1-d index.""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + y = pt.vector("y", shape=(50,)) + fn, fn_u = fused_and_unfused( + [x, scalar_idx, vec_idx, y], x[scalar_idx, vec_idx] + y + ) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + sv = np.array(rng.integers(100), dtype=np.int64) + vv = rng.integers(200, size=50).astype(np.int64) + yv = rng.normal(size=50) + np.testing.assert_allclose(fn(xv, sv, vv, yv), fn_u(xv, sv, vv, yv), rtol=1e-10) + + def test_vector_and_scalar_index(self): + """x[vector_idx, scalar_idx] — reversed order.""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + y = pt.vector("y", shape=(50,)) + fn, fn_u = fused_and_unfused( + [x, vec_idx, scalar_idx, y], x[vec_idx, scalar_idx] + y + ) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + vv = rng.integers(100, 
size=50).astype(np.int64) + sv = np.array(rng.integers(200), dtype=np.int64) + yv = rng.normal(size=50) + np.testing.assert_allclose(fn(xv, vv, sv, yv), fn_u(xv, vv, sv, yv), rtol=1e-10) + + def test_scalar_and_vector_index_with_trailing_dim(self): + """x[scalar_idx, vector_idx] on a 3-d tensor with trailing dim.""" + rng = np.random.default_rng(42) + x = pt.tensor3("x", shape=(100, 200, 7)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + z = pt.matrix("z", shape=(50, 7)) + fn, fn_u = fused_and_unfused( + [x, scalar_idx, vec_idx, z], x[scalar_idx, vec_idx] + z + ) + assert_fused(fn) + xv = rng.normal(size=(100, 200, 7)) + sv = np.array(rng.integers(100), dtype=np.int64) + vv = rng.integers(200, size=50).astype(np.int64) + zv = rng.normal(size=(50, 7)) + np.testing.assert_allclose(fn(xv, sv, vv, zv), fn_u(xv, sv, vv, zv), rtol=1e-10) + + def test_all_scalar_indices(self): + """AdvancedSubtensor with all 0-d indices (degenerate case).""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + si0 = pt.scalar("si0", dtype="int64") + si1 = pt.scalar("si1", dtype="int64") + y = pt.scalar("y", dtype="float64") + adv_sub = AdvancedSubtensor(idx_list=(0, 1))(x, si0, si1) + fn, fn_u = fused_and_unfused([x, si0, si1, y], adv_sub + y) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + s0 = np.array(rng.integers(100), dtype=np.int64) + s1 = np.array(rng.integers(200), dtype=np.int64) + yv = np.array(rng.normal()) + np.testing.assert_allclose(fn(xv, s0, s1, yv), fn_u(xv, s0, s1, yv), rtol=1e-10) + def test_negative_indices(self): """Negative indices must be handled correctly (sign-extended, not zero-extended).""" rng = np.random.default_rng(42) @@ -307,6 +426,22 @@ def test_combined_multi_index_read_write(self): fn(tv, sv, rv, cv, xv), fn_u(tv, sv, rv, cv, xv), rtol=1e-10 ) + def test_scalar_and_vector_index_inc(self): + """inc_subtensor with x[scalar_idx, vector_idx] — 0-d and 1-d 
indices.""" + rng = np.random.default_rng(42) + target = pt.matrix("target", shape=(100, 200)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + x = pt.vector("x", shape=(50,)) + out = pt.inc_subtensor(target[scalar_idx, vec_idx], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, scalar_idx, vec_idx, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200)) + sv = np.array(rng.integers(100), dtype=np.int64) + vv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=50) + np.testing.assert_allclose(fn(tv, sv, vv, xv), fn_u(tv, sv, vv, xv), rtol=1e-10) + class TestShapeValidation: """Test that mismatched index/input shapes raise runtime errors. @@ -411,6 +546,20 @@ def test_shared_read_write_index_no_broadcast(self): np.array([0], dtype=np.int64), ) + def test_mismatched_nd_index_dims(self): + """ND index shape doesn't match direct input shape.""" + x = pt.vector("x", shape=(None,)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(None, None)) + y = pt.matrix("y", shape=(None, None)) + out = x[mat_idx] + y + fn = pytensor.function([x, mat_idx, y], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Matching: mat_idx=(3,4), y=(3,4) — should work + fn(np.zeros(100), np.zeros((3, 4), dtype=np.int64), np.zeros((3, 4))) + # Mismatched: mat_idx=(3,4), y=(3,5) — should error + with pytest.raises(Exception): + fn(np.zeros(100), np.zeros((3, 4), dtype=np.int64), np.zeros((3, 5))) + # ============================================================ # Radon model integration test # ============================================================ From 10846566c34c53033bae9bd33d19623054ec8803 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 1 Apr 2026 22:54:34 +0200 Subject: [PATCH 6/7] Numba: allow broadcast scatter in IndexedElemwise Allow the written buffer in IndexedElemwise to be larger than scalar in the core loop. 
When the Elemwise computes a scalar per index iteration and the target has extra non-indexed dims, the scalar broadcasts into those dims via o[...] = scalar. Key changes: - Generalize _getitem_0d_ellipsis to any ndim so o[...] += scalar works for 1-d+ core output arrays - Relax the broadcast guard to allow val with fewer non-indexed dims when the indexed axes are rightmost (no trailing non-indexed dims) - Transpose the target buffer in the fusion rewrite when non-indexed dims precede the indexed axis, so core dims are trailing (gufunc convention) --- pytensor/link/numba/dispatch/elemwise.py | 2 +- .../link/numba/dispatch/vectorize_codegen.py | 21 ++-- pytensor/tensor/rewriting/indexed_elemwise.py | 119 +++++++++++++++--- tests/link/numba/test_indexed_elemwise.py | 44 ++++++- 4 files changed, 151 insertions(+), 35 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 97d276c373..73bc56a324 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -489,7 +489,7 @@ def impl(*outer_inputs): return impl - cache_version = (0, 4) + cache_version = (0, 5) if scalar_cache_key is None: key = None else: diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index 999fa27c72..cf03d4ed51 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -21,20 +21,23 @@ from pytensor.link.numba.dispatch import basic as numba_basic -# Numba is missing getitem(0d_array, Ellipsis), so o[...] += val fails. +# Numba is missing getitem(array, Ellipsis), so o[...] += val fails. # Register it so store_core_outputs can use o[...] += val naturally. 
@overload(operator.getitem, inline="always") def _getitem_0d_ellipsis(arr, idx): - if ( - isinstance(arr, types.Array) - and arr.ndim == 0 - and isinstance(idx, types.EllipsisType) - ): + if isinstance(arr, types.Array) and isinstance(idx, types.EllipsisType): + if arr.ndim == 0: + + def impl(arr, idx): + return arr[()] + + return impl + else: - def impl(arr, idx): - return arr[()] + def impl(arr, idx): + return arr - return impl + return impl def encode_literals(literals: Sequence) -> str: diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index 7670bea42a..580a975bd3 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -438,23 +438,51 @@ def find_indexed_update_consumers(fgraph, node): if info is None: continue target, idx_axis_pairs, mode = info - # Don't fuse if the value broadcasts on the index loop dim - # (constant across index — recomputing per position is wasteful) - # or against non-indexed target axes. + # Don't fuse if val is broadcastable on any index loop dim — + # the Elemwise result is constant across the index and + # recomputing it per index position is wasteful. val = client_node.inputs[1] n_idx_dims = max((idx.ndim for idx, _ in idx_axis_pairs), default=0) val_idx_bc = list(val.type.broadcastable)[:n_idx_dims] if any(val_idx_bc): continue + + # Check broadcast compatibility between val and target on + # non-indexed axes. Val may have fewer non-indexed dims + # than target — the extra target dims become core dims + # (broadcast output) — but ONLY when the indexed axes are + # rightmost in the target so that val's dims align with the + # index loop (numpy broadcasts from the right). When + # indexed axes are NOT rightmost, val's dims would align + # with trailing non-indexed axes instead, giving wrong + # results if fused. 
indexed_axes = {a for _, a in idx_axis_pairs} + max_indexed_axis = max(indexed_axes) + has_trailing_non_indexed = any( + a > max_indexed_axis + for a in range(target.type.ndim) + if a not in indexed_axes + ) non_indexed_target_bc = [ bc for i, bc in enumerate(target.type.broadcastable) if i not in indexed_axes ] non_indexed_val_bc = list(val.type.broadcastable) + n_idx_dims = max((idx.ndim for idx, _ in idx_axis_pairs), default=0) non_indexed_val_bc = non_indexed_val_bc[n_idx_dims:] - if len(non_indexed_val_bc) < len(non_indexed_target_bc) or any( + if len(non_indexed_val_bc) < len(non_indexed_target_bc): + # Val has fewer dims — broadcasting into extra target + # dims. Only safe when indexed axes are rightmost + # (no trailing non-indexed dims). + # TODO: could also handle trailing non-indexed dims + # that are broadcastable (size-1) in the val — squeeze + # them away before fusing. Requires excluding + # local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1 + # to encounter these graphs naturally. + if has_trailing_non_indexed: + continue + elif any( vbc and not tbc for vbc, tbc in zip( non_indexed_val_bc, non_indexed_target_bc, strict=False @@ -593,15 +621,50 @@ def find_indexed_update_consumers(fgraph, node): dup_map = {} # Inner fgraph outputs; add update targets + # For non-axis-0 indexed updates, transpose the target so the + # indexed axis moves to position 0 and core dims are trailing + # (matching the gufunc convention assumed by the codegen). 
inner_outputs = list(node.outputs) call_inputs = list(inner_inputs) + scatter_transpose_back = {} # scatter_idx -> inverse perm for out_idx in sorted(update_consumers.keys()): update_node, target, idx_axis_pairs, _mode = update_consumers[out_idx] + # Check if we need to transpose the target + max_idx_axis = max(a for _, a in idx_axis_pairs) + needs_transpose = any( + a < max_idx_axis + for a in range(target.type.ndim) + if a not in {a for _, a in idx_axis_pairs} + ) + if needs_transpose: + # Build permutation: indexed axes first, then rest + idx_axes_sorted = sorted({a for _, a in idx_axis_pairs}) + non_idx_axes = [ + a for a in range(target.type.ndim) if a not in idx_axes_sorted + ] + perm = idx_axes_sorted + non_idx_axes + inv_perm = [0] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + target = target.dimshuffle(perm) + # Remap idx_axis_pairs to the transposed axes + axis_remap = {old: new for new, old in enumerate(perm)} + idx_axis_pairs = [ + (idx_var, axis_remap[axis]) for idx_var, axis in idx_axis_pairs + ] + # Update the consumers entry with remapped axes + update_consumers[out_idx] = ( + update_node, + target, + idx_axis_pairs, + _mode, + ) + inner_inputs.append(target) target_pos = len(call_inputs) - if update_node.op.inplace: + if update_node.op.inplace and not needs_transpose: call_inputs.append(target) else: call_inputs.append(target.copy()) @@ -611,9 +674,20 @@ def find_indexed_update_consumers(fgraph, node): # The value to scatter: use the (possibly duplicated) Elemwise output scatter_value = node.outputs[scatter_idx] - # Build the scatter output using the correct value - if update_node.op.inplace: - # Rebuild with the new value + # Build the scatter output + # After transpose, indexed axes are at position 0+, + # so always use AdvancedIncSubtensor1 for single axis-0. 
+ if len(idx_axis_pairs) == 1 and idx_axis_pairs[0][1] == 0: + inplace_op = AdvancedIncSubtensor1( + inplace=True, + set_instead_of_inc=update_node.op.set_instead_of_inc, + ) + scatter_out = inplace_op( + target, + scatter_value, + idx_axis_pairs[0][0], + ) + elif update_node.op.inplace and not needs_transpose: if isinstance(update_node.op, AdvancedIncSubtensor1): scatter_out = update_node.op( target, scatter_value, update_node.inputs[2] @@ -622,14 +696,6 @@ def find_indexed_update_consumers(fgraph, node): scatter_out = update_node.op( target, scatter_value, *update_node.inputs[2:] ) - elif isinstance(update_node.op, AdvancedIncSubtensor1): - inplace_op = AdvancedIncSubtensor1( - inplace=True, - set_instead_of_inc=update_node.op.set_instead_of_inc, - ) - scatter_out = inplace_op( - target, scatter_value, update_node.inputs[2] - ) else: inplace_op = AdvancedIncSubtensor( idx_list=update_node.op.idx_list, @@ -642,6 +708,9 @@ def find_indexed_update_consumers(fgraph, node): inner_outputs[scatter_idx] = scatter_out outer_destroy_map[scatter_idx] = [target_pos] + if needs_transpose: + scatter_transpose_back[scatter_idx] = inv_perm + # Build indexed_inputs spec for the Op indexed_inputs_spec = [None] * n_indices for idx_var, _ax, positions in read_groups: @@ -660,13 +729,21 @@ def find_indexed_update_consumers(fgraph, node): indexed_inputs_spec[k] = ((), ax, iv.type.broadcastable) break - # Build indexed_outputs spec for the Op + # Build indexed_outputs spec for the Op. + # Use the (possibly remapped) axis from update_consumers. indexed_outputs_spec = [None] * n_indices for out_idx in sorted(update_consumers.keys()): _update_node, _target, idx_axis_pairs, mode = update_consumers[out_idx] for idx_var, axis in idx_axis_pairs: - key = (idx_var, axis) - idx_pos = all_idx_groups[key][1] + # Look up in all_idx_groups using the ORIGINAL key + # (all_idx_groups was built before any transpose). 
+ idx_pos = None + for (iv, _), (v, p) in all_idx_groups.items(): + if iv is idx_var and p is not None: + idx_pos = p + break + if idx_pos is None: + continue scatter_idx = dup_map.get(out_idx, out_idx) if indexed_outputs_spec[idx_pos] is None: indexed_outputs_spec[idx_pos] = ([scatter_idx], mode, axis) @@ -685,6 +762,10 @@ def find_indexed_update_consumers(fgraph, node): indexed_outputs=indexed_outputs_spec, )(*call_inputs, return_list=True) + # Transpose back any scatter outputs that were transposed + for scatter_idx, inv_perm in scatter_transpose_back.items(): + new_outs[scatter_idx] = new_outs[scatter_idx].dimshuffle(inv_perm) + # The node whose outputs we need to replace in the outer graph orig_node = old_node if multi_client_outs else node replacements = [] diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index c0773ee0fc..049062d118 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -269,10 +269,13 @@ def test_negative_indices(self): class TestIndexedUpdateFusion: """Test indexed updates (AdvancedIncSubtensor1) fused into Elemwise.""" - def test_no_fusion_when_val_broadcasts_against_target(self): - """Don't fuse (yet) if elemwise output broadcasts against target's trailing axes. + def test_no_fusion_when_idx_axes_outside_elemwise_loop(self): + """Don't fuse if the indexed axes are not within the Elemwise loop. - TODO: support this by making the update output's core_ndim > 0. + Here the index is on axis 0 of target(5, 10), but the Elemwise + output (10,) corresponds to axis 1 (the non-indexed trailing axis). + The indexed axis doesn't overlap with the Elemwise computation, so + fusing would misalign which input dims map to which target dims. 
""" rng = np.random.default_rng(42) idx = rng.integers(5, size=10).astype(np.int64) @@ -282,9 +285,9 @@ def test_no_fusion_when_val_broadcasts_against_target(self): elemwise_out = advanced_subtensor1(x, idx) + y # shape (10,) out = pt.inc_subtensor(target[idx], elemwise_out) fn, fn_u = fused_and_unfused([x, y, target], out) - # Write not fused — val (10,) would broadcast to (10, 10) in the target. - # Read fusion still creates an IndexedElemwise, but the - # AdvancedIncSubtensor1 must remain outside it. + # Write not fused — the Elemwise loop dim is the non-indexed axis, + # not the indexed axis. Read fusion may still create an + # IndexedElemwise, but the AdvancedIncSubtensor1 must remain outside. assert any( isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() ) @@ -314,6 +317,35 @@ def test_no_fusion_when_val_broadcasts_along_index_dim(self): tv = rng.normal(size=(5,)) np.testing.assert_allclose(fn(xv, tv), fn_u(xv, tv), rtol=1e-10) + def test_broadcast_val_into_non_indexed_dims(self): + """Fuse when Elemwise output broadcasts into target's non-indexed dims. + + target(5, 3)[:, idx] += exp(val) — the Elemwise loop covers the + index axis (axis 1), and the scalar result broadcasts into the + non-indexed leading axis (axis 0, core_ndim=1). + + Requires excluding local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1 + so we keep AdvancedIncSubtensor on the correct axis. 
+ """ + rng = np.random.default_rng(42) + idx = rng.integers(3, size=10).astype(np.int64) + target = pt.matrix("target", shape=(5, 3)) + val = pt.vector("val", shape=(10,)) + out = pt.inc_subtensor(target[:, idx], pt.exp(val)) + + mode = NUMBA_MODE.excluding( + "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1" + ) + mode_u = NUMBA_NO_FUSION.excluding( + "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1" + ) + fn = pytensor.function([val, target], out, mode=mode, trust_input=True) + fn_u = pytensor.function([val, target], out, mode=mode_u, trust_input=True) + assert_fused(fn) + valv = rng.normal(size=(10,)) + tv = rng.normal(size=(5, 3)) + np.testing.assert_allclose(fn(valv, tv), fn_u(valv, tv), rtol=1e-10) + def test_inc_subtensor(self): rng = np.random.default_rng(42) idx = rng.integers(85, size=919).astype(np.int64) From e286e587272e2a9463e3f6035ee343b581e180c1 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Apr 2026 18:10:01 +0200 Subject: [PATCH 7/7] WIP fixes --- pytensor/compile/function/types.py | 16 ++++++ pytensor/tensor/rewriting/elemwise.py | 41 +++++++++----- pytensor/tensor/rewriting/indexed_elemwise.py | 54 ++++++++++++------- tests/benchmarks/test_rewriting.py | 2 +- tests/compile/function/test_types.py | 32 +++++++---- tests/tensor/rewriting/test_math.py | 4 +- 6 files changed, 107 insertions(+), 42 deletions(-) diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 0649f3935b..cb360b8b53 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -723,6 +723,22 @@ def checkSV(sv_ori, sv_rpl): # The first len(maker.outputs) variables are original variables. # The rest are the updates. out_vars = maker.fgraph.outputs[: len(maker.outputs)] + + # Warn if any shared variable with a deleted update is being + # destroyed (mutated inplace) by a node in the graph. In that case + # the shared variable storage will still be mutated even though the + # update is not applied. 
+        if hasattr(maker.fgraph, "destroyers"):
+            for in_idx in maker.fgraph.update_mapping.values():
+                input_var = maker.fgraph.inputs[in_idx]
+                if maker.fgraph.destroyers(input_var):
+                    warnings.warn(
+                        f"Shared variable '{input_var.name}' will still be mutated "
+                        "even though its update is being deleted, because an "
+                        "operation in the graph modifies it inplace.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
         else:
             out_vars = maker.fgraph.outputs
 
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index 517819da98..59e84c56ac 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -906,9 +906,10 @@ def elemwise_scalar_op_has_c_code(
             # producer-consumer edges, so it misses siblings like f(x) and
             # g(x) that share an input without one feeding into the other.
             sibling_candidates: list = list(sorted_subgraphs)
-            for node in candidate_starting_nodes:
-                bf = nodes_bitflags.get(node)
-                if (bf is not None) and not (bf & all_subgraphs_bitset):
+            for node, bf in nodes_bitflags.items():
+                if node not in candidate_starting_nodes:
+                    continue
+                if not (bf & all_subgraphs_bitset):
                     sibling_candidates.append(
                         (bf, (tuple(dict.fromkeys(node.inputs)), node.outputs))
                     )
@@ -916,32 +917,49 @@ def elemwise_scalar_op_has_c_code(
             # Skip scalar constants as they get inlined later and don't
             # represent meaningful shared computation between siblings.
             input_to_sibling_idxs: dict[Variable, list[int]] = defaultdict(list)
-            for sibling_idx, (_, (inputs, _)) in enumerate(sibling_candidates):
+            for sibling_idx, (_, (inputs, outputs)) in enumerate(sibling_candidates):
+                out_bcast = outputs[0].type.broadcastable
                 for inp in inputs:
                     if isinstance(inp, TensorConstant) and inp.unique_value is not None:
                         continue
+                    # Only group siblings through inputs that aren't broadcasted,
+                    # as we may be dealing with siblings of different shapes.
+ # This is conservative, we could check if other shared inputs + # jointly imply equal shapes (or look at static shapes if available) + if inp.type.broadcastable != out_bcast: + continue input_to_sibling_idxs[inp].append(sibling_idx) + # Track which candidate each index was merged into (union-find) + merged_into: dict[int, int] = {} + + def find_canonical(idx): + """Follow merge chain to find the live candidate.""" + try: + while True: + idx = merged_into[idx] + except KeyError: + return idx + for sibling_idxs in input_to_sibling_idxs.values(): if len(sibling_idxs) < 2: continue for i in range(len(sibling_idxs) - 1): - sibling_i = sibling_idxs[i] - if sibling_candidates[sibling_i] is None: - # Already merged in a previous iteration - continue + sibling_i = find_canonical(sibling_idxs[i]) bitset_i, (inputs_i, outputs_i) = sibling_candidates[sibling_i] bcast_i = outputs_i[0].type.broadcastable merged = False for sibling_j in sibling_idxs[i + 1 :]: - if sibling_candidates[sibling_j] is None: - # Already merged in a previous iteration + sibling_j = find_canonical(sibling_j) + if sibling_j == sibling_i: + # Already in the same group continue bitset_j, (inputs_j, outputs_j) = sibling_candidates[sibling_j] if bcast_i != outputs_j[0].type.broadcastable: continue + # Independence: neither is ancestor of the other if ( bitset_i & ancestors_bitsets[outputs_j[0].owner] @@ -958,8 +976,7 @@ def elemwise_scalar_op_has_c_code( if input_j not in merged_inputs_set: merged_inputs.append(input_j) merged_inputs_set.add(input_j) - # Hide sibling_j from future merges - sibling_candidates[sibling_j] = None + merged_into[sibling_j] = sibling_i # Update ancestor bitsets so that any node # depending on part of the merged group now diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index 580a975bd3..febc461b0c 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -15,7 +15,7 @@ 
from pytensor.scalar.basic import Composite, identity from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer -from pytensor.tensor.shape import Reshape +from pytensor.tensor.shape import Reshape, shape_padright from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -250,7 +250,7 @@ class IndexedElemwise(OpFromGraph): def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): self.indexed_inputs = indexed_inputs self.indexed_outputs = indexed_outputs - super().__init__(*args, on_unused_input="ignore", **kwargs) + super().__init__(*args, on_unused_input="ignore", accept_inplace=True, **kwargs) def __str__(self): for node in self.fgraph.apply_nodes: @@ -627,16 +627,25 @@ def find_indexed_update_consumers(fgraph, node): inner_outputs = list(node.outputs) call_inputs = list(inner_inputs) scatter_transpose_back = {} # scatter_idx -> inverse perm + # Save original idx_axis_pairs before transpose may remap axes, + # so we can look up the correct key in all_idx_groups later. + original_idx_axis_pairs = { + out_idx: list(entry[2]) for out_idx, entry in update_consumers.items() + } for out_idx in sorted(update_consumers.keys()): update_node, target, idx_axis_pairs, _mode = update_consumers[out_idx] - # Check if we need to transpose the target - max_idx_axis = max(a for _, a in idx_axis_pairs) - needs_transpose = any( - a < max_idx_axis - for a in range(target.type.ndim) - if a not in {a for _, a in idx_axis_pairs} - ) + # If there are batch dims on the target (beyond the elemwise loop), + # we transpose them to the right, so they work as "core_ndim" + # in the vectorized codegen (which must always be on the right) + # (e.g. 
target(5,3)[:,idx].inc(exp(vector)) -> target(3, 5)[idx, :].inc(exp(vector)),
+            idx_axes_set = {a for _, a in idx_axis_pairs}
+            max_idx_axis = max(idx_axes_set)
+            n_indexed_axes = len(idx_axes_set)
+            n_idx_dims = max((idx.ndim for idx, _ in idx_axis_pairs), default=0)
+            elemwise_batch_ndim = len(node.outputs[0].type.broadcastable)
+            source_batch = elemwise_batch_ndim + n_indexed_axes - n_idx_dims
+            needs_transpose = max_idx_axis >= source_batch
             if needs_transpose:
                 # Build permutation: indexed axes first, then rest
                 idx_axes_sorted = sorted({a for _, a in idx_axis_pairs})
@@ -674,6 +683,15 @@ def find_indexed_update_consumers(fgraph, node):
             # The value to scatter: use the (possibly duplicated) Elemwise output
             scatter_value = node.outputs[scatter_idx]
 
+            # To work with IndexedElemwise codegen, we moved indexed update batch dims to the right.
+            # target(5,3)[:,idx].inc(exp(vector)) -> target(3, 5)[idx, :].inc(exp(vector))
+            # For this operation to be strictly valid, we need to expand the dims of the scattered value.
+            # target(5,3)[:,idx].inc(exp(vector)) -> target(3, 5)[idx, :].inc(exp(vector)[:, None])
+            # This expand_dims is subsumed by the IndexedElemwise, but makes the inner graph valid.
+            core_ndim = target.type.ndim - source_batch
+            if needs_transpose and core_ndim > 0:
+                scatter_value = shape_padright(scatter_value, core_ndim)
+
             # Build the scatter output
             # After transpose, indexed axes are at position 0+,
             # so always use AdvancedIncSubtensor1 for single axis-0.
@@ -730,18 +748,18 @@ def find_indexed_update_consumers(fgraph, node):
                     break
 
         # Build indexed_outputs spec for the Op.
-        # Use the (possibly remapped) axis from update_consumers.
+        # Use the (possibly remapped) axis from update_consumers for the
+        # spec value, but the ORIGINAL axis for the all_idx_groups lookup
+        # (since all_idx_groups was built before any transpose).
indexed_outputs_spec = [None] * n_indices for out_idx in sorted(update_consumers.keys()): _update_node, _target, idx_axis_pairs, mode = update_consumers[out_idx] - for idx_var, axis in idx_axis_pairs: - # Look up in all_idx_groups using the ORIGINAL key - # (all_idx_groups was built before any transpose). - idx_pos = None - for (iv, _), (v, p) in all_idx_groups.items(): - if iv is idx_var and p is not None: - idx_pos = p - break + orig_pairs = original_idx_axis_pairs[out_idx] + for (idx_var, axis), (_orig_var, orig_axis) in zip( + idx_axis_pairs, orig_pairs + ): + key = (idx_var, orig_axis) + idx_pos = all_idx_groups.get(key, (None, None))[1] if idx_pos is None: continue scatter_idx = dup_map.get(out_idx, out_idx) diff --git a/tests/benchmarks/test_rewriting.py b/tests/benchmarks/test_rewriting.py index cea3c4e676..5b0b86ba6f 100644 --- a/tests/benchmarks/test_rewriting.py +++ b/tests/benchmarks/test_rewriting.py @@ -45,7 +45,7 @@ def _deep_small_kernels(n): "graph_fn, n, expected_n_repl", [ ("deep_small_kernels", 20, (20, 60)), - ("large_fuseable_graph", 25, (128, 876)), + ("large_fuseable_graph", 25, (55, 901)), ], ) def test_fusion_rewrite_benchmark(graph_fn, n, expected_n_repl, benchmark): diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index cb2c27fada..cdd079a2e1 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -491,6 +491,8 @@ def test_swap_SharedVariable_with_given(self): assert in1.value is in2.value def test_copy_delete_updates(self): + from contextlib import nullcontext + w = iscalar("w") x = fscalar("x") # SharedVariable for tests, one of them has update @@ -498,25 +500,35 @@ def test_copy_delete_updates(self): z = shared(value=2, name="z") out = x + y + z - # Test for different linkers - # for mode in ["FAST_RUN","FAST_COMPILE"]: - # second_time = False + inplace_warn = pytest.warns(UserWarning, match="will still be mutated") + for mode in ("FAST_RUN", 
"FAST_COMPILE"): + y.set_value(1) + z.set_value(2) + # FAST_RUN applies inplace rewrites, so z will be destroyed. + # A warning is issued because dropping the update doesn't prevent + # the shared variable storage from being mutated. + ctx = inplace_warn if mode == "FAST_RUN" else nullcontext() ori = function([x], out, mode=mode, updates={z: z * 2}) - cpy = ori.copy(delete_updates=True) - - assert cpy(1) == 4 - assert cpy(1) == 4 - assert cpy(1) == 4 + with ctx: + cpy = ori.copy(delete_updates=True) + if mode == "FAST_COMPILE": + assert cpy(1) == 4 + assert cpy(1) == 4 + assert cpy(1) == 4 # Test if unused implicit and explicit inputs from delete_updates # are ignored as intended. for mode in ("FAST_RUN", "FAST_COMPILE"): + ctx = inplace_warn if mode == "FAST_RUN" else nullcontext() + ori = function([x], x, mode=mode, updates={z: z * 2}) - cpy = ori.copy(delete_updates=True) + with ctx: + cpy = ori.copy(delete_updates=True) ori = function([x, w], x, mode=mode, updates={z: z + w}) - cpy = ori.copy(delete_updates=True) + with ctx: + cpy = ori.copy(delete_updates=True) def test_shared_state0(self): a = scalar() # the a is for 'anonymous' (un-named). diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 406cd06f48..3f50c32425 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4651,7 +4651,9 @@ def test_polygamma_specialization(): y3 = polygamma(2, x) fn = pytensor.function( - [x], [y1, y2, y3], mode=get_default_mode().including("specialize") + [x], + [y1, y2, y3], + mode=get_default_mode().including("specialize").excluding("fusion"), ) fn_outs = fn.maker.fgraph.outputs assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)