diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 0649f3935b..cb360b8b53 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -723,6 +723,22 @@ def checkSV(sv_ori, sv_rpl): # The first len(maker.outputs) variables are original variables. # The rest are the updates. out_vars = maker.fgraph.outputs[: len(maker.outputs)] + + # Warn if any shared variable with a deleted update is being + # destroyed (mutated inplace) by a node in the graph. In that case + # the shared variable storage will still be mutated even though the + # update is not applied. + if hasattr(maker.fgraph, "destroyers"): + for in_idx in maker.fgraph.update_mapping.values(): + input_var = maker.fgraph.inputs[in_idx] + if maker.fgraph.destroyers(input_var): + warnings.warn( + f"Shared variable '{input_var.name}' will still be mutated " + "even though its update is being deleted, because an " + "operation in the graph modifies it inplace.", + UserWarning, + stacklevel=2, + ) else: out_vars = maker.fgraph.outputs diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index f2405e3542..af4aa4aee3 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -11,6 +11,9 @@ register_funcify_and_cache_key, ) from pytensor.link.numba.dispatch.vectorize_codegen import ( + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, + NO_SIZE, _jit_options, _vectorized, encode_literals, @@ -90,7 +93,9 @@ def impl(*inputs_and_core_shapes): (), # constant_inputs inputs, tuple_core_shapes, - None, # size + NO_SIZE, + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, ) return impl diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 2308ddffdc..73bc56a324 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -20,6 +20,10 @@ ) from pytensor.link.numba.dispatch.string_codegen import 
create_tuple_string from pytensor.link.numba.dispatch.vectorize_codegen import ( + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, + NO_SIZE, + _jit_options, _vectorized, encode_literals, store_core_outputs, @@ -35,13 +39,13 @@ Mul, Sub, TrueDiv, - get_scalar_type, maximum, ) from pytensor.scalar.basic import add as add_as from pytensor.tensor.blas import BatchedDot from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum +from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -312,8 +316,7 @@ def axis_apply_fn(x): @register_funcify_and_cache_key(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): - scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs] - scalar_node = op.scalar_op.make_node(*scalar_inputs) + scalar_node = op.make_scalar_node(*node.inputs) scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key( op.scalar_op, node=scalar_node, @@ -368,8 +371,10 @@ def impl(*inputs): True, # allow_core_scalar (), # constant_inputs inputs, - core_output_shapes, # core_shapes - None, # size + core_output_shapes, + NO_SIZE, + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, ) return impl @@ -390,6 +395,121 @@ def impl(*inputs): return elemwise, elemwise_key +@register_funcify_and_cache_key(IndexedElemwise) +def numba_funcify_IndexedElemwise(op, node, **kwargs): + """Generate fused Elemwise Numba code with indexed reads and updates. + + Reads indexed_inputs/indexed_outputs specs stored on the Op by the + rewriting pass, and generates a single vectorized loop with indirect + indexing. + + Outer inputs are ordered as:: + + [elemwise_inputs..., idx_0, idx_1, ..., update_target_0, ...] 
+ """ + [elemwise_node] = [n for n in op.fgraph.apply_nodes if isinstance(n.op, Elemwise)] + + indexed_inputs = op.indexed_inputs + indexed_outputs = op.indexed_outputs + + # Derive update_out_set and inc_outputs from stored specs + update_out_set = frozenset( + out_idx + for entry in indexed_outputs + if entry is not None + for out_idx in entry[0] + ) + inc_outputs = frozenset( + out_idx + for entry in indexed_outputs + if entry is not None + for out_idx in entry[0] + if entry[1] == "inc" + ) + + # --- Scalar function and core op -------------------------------------- + scalar_node = elemwise_node.op.make_scalar_node(*elemwise_node.inputs) + scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key( + elemwise_node.op.scalar_op, node=scalar_node, **kwargs + ) + + nin_elemwise = len(elemwise_node.inputs) + nout = len(elemwise_node.outputs) + core_op_fn = store_core_outputs( + scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs + ) + + # --- Broadcast and type encodings ------------------------------------- + # Use each input's result broadcastable (not the source's). + # For indexed reads, the result has the index dim(s) substituted. + # For multi-index, the result has fewer dims than the source. + input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs) + # Use Elemwise output bc (loop dims) for ALL outputs — including updates. + # The fix-up block in _vectorized corrects the actual output/return types + # for update outputs to match the target buffer. 
+ output_bc_patterns = tuple(out.type.broadcastable for out in elemwise_node.outputs) + output_dtypes = tuple(out.type.dtype for out in node.outputs) + # Filter out elemwise inplace for indexed-update outputs (they use update targets) + inplace_pattern = tuple( + (out_idx, inp_idx) + for out_idx, inp_idx in elemwise_node.op.inplace_pattern.items() + if out_idx not in update_out_set + ) + core_output_shapes = tuple(() for _ in range(nout)) + + input_bc_patterns_enc = encode_literals(input_bc_patterns) + output_bc_patterns_enc = encode_literals(output_bc_patterns) + output_dtypes_enc = encode_literals(output_dtypes) + inplace_pattern_enc = encode_literals(inplace_pattern) + indexed_inputs_enc = encode_literals(indexed_inputs) + indexed_outputs_enc = encode_literals(indexed_outputs) + + def indexed_elemwise_fn(*outer_inputs): + raise NotImplementedError( + "IndexedElemwise cannot be evaluated in Python (non-JIT) mode." + ) + + @overload(indexed_elemwise_fn, jit_options=_jit_options) + def ov_indexed_elemwise_fn(*outer_inputs): + def impl(*outer_inputs): + return _vectorized( + core_op_fn, + input_bc_patterns_enc, + output_bc_patterns_enc, + output_dtypes_enc, + inplace_pattern_enc, + True, # allow_core_scalar + (), # constant_inputs + outer_inputs, + core_output_shapes, + NO_SIZE, + indexed_inputs_enc, + indexed_outputs_enc, + ) + + return impl + + cache_version = (0, 5) + if scalar_cache_key is None: + key = None + else: + key = str( + ( + type(op), + "IndexedElemwise", + cache_version, + inplace_pattern, + input_bc_patterns, + indexed_inputs, + indexed_outputs, + scalar_cache_key, + ) + ) + key = sha256(key.encode()).hexdigest() + + return indexed_elemwise_fn, key + + @register_funcify_and_cache_key(CAReduce) def numba_funcify_CAReduce(op, node, **kwargs): axes = op.axis diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 0ac6a301cc..09e69603f6 100644 --- a/pytensor/link/numba/dispatch/random.py +++ 
b/pytensor/link/numba/dispatch/random.py @@ -21,6 +21,9 @@ ) from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy from pytensor.link.numba.dispatch.vectorize_codegen import ( + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, + NO_SIZE, _jit_options, _vectorized, encode_literals, @@ -474,9 +477,11 @@ def impl(core_shape, rng, size, *dist_params): (rng,), dist_params, (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),), - None + NO_SIZE if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len), + NO_INDEXED_INPUTS, + NO_INDEXED_OUTPUTS, ) return rng, draws diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index f804c0c04c..cf03d4ed51 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import operator import pickle from collections.abc import Callable, Sequence from textwrap import indent @@ -13,27 +14,53 @@ from numba.core import cgutils from numba.core.base import BaseContext from numba.core.types.misc import NoneType +from numba.extending import overload from numba.np import arrayobj from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic +# Numba is missing getitem(array, Ellipsis), so o[...] += val fails. +# Register it so store_core_outputs can use o[...] += val naturally. 
+@overload(operator.getitem, inline="always") +def _getitem_0d_ellipsis(arr, idx): + if isinstance(arr, types.Array) and isinstance(idx, types.EllipsisType): + if arr.ndim == 0: + + def impl(arr, idx): + return arr[()] + + return impl + else: + + def impl(arr, idx): + return arr + + return impl + + def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() -def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: +def store_core_outputs( + core_op_fn: Callable, nin: int, nout: int, inc_outputs: frozenset = frozenset() +) -> Callable: """Create a Numba function that wraps a core function and stores its vectorized outputs. @njit def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): to0, to1, ..., ton = core_op_fn(i0, i1, ..., in) - o0[...] = to0 - o1[...] = to1 + o0[...] = to0 # direct outputs + o1[...] += to1 # inc outputs (indexed update) ... - on[...] = ton + Parameters + ---------- + inc_outputs : frozenset + Output indices that use ``+=`` (increment) instead of ``=`` (assign). + Used for indexed-update outputs in fused loops. """ inputs = [f"i{i}" for i in range(nin)] outputs = [f"o{i}" for i in range(nout)] @@ -43,8 +70,12 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): out_signature = ", ".join(outputs) inner_out_signature = ", ".join(inner_outputs) store_outputs = "\n".join( - f"{output}[...] = {inner_output}" - for output, inner_output in zip(outputs, inner_outputs, strict=True) + f"{output}[...] += {inner_output}" + if i in inc_outputs + else f"{output}[...] 
= {inner_output}" + for i, (output, inner_output) in enumerate( + zip(outputs, inner_outputs, strict=True) + ) ) func_src = f""" def store_core_outputs({inp_signature}, {out_signature}): @@ -74,73 +105,27 @@ def store_core_outputs({inp_signature}, {out_signature}): } -@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) -def _vectorized( - typingctx, - core_func, +def _decode_literal(val, name): + if not isinstance(val, types.Literal): + raise TypingError(f"{name} must be literal.") + return pickle.loads(base64.decodebytes(val.literal_value.encode())) + + +def _compute_vectorized_types( + input_types, input_bc_patterns, output_bc_patterns, output_dtypes, inplace_pattern, allow_core_scalar, - constant_inputs_types, - input_types, output_core_shape_types, - size_type, ): - arg_types = [ - core_func, - input_bc_patterns, - output_bc_patterns, - output_dtypes, - inplace_pattern, - allow_core_scalar, - constant_inputs_types, - input_types, - output_core_shape_types, - size_type, - ] - - if not isinstance(input_bc_patterns, types.Literal): - raise TypingError("input_bc_patterns must be literal.") - input_bc_patterns = input_bc_patterns.literal_value - input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode())) - - if not isinstance(output_bc_patterns, types.Literal): - raise TypeError("output_bc_patterns must be literal.") - output_bc_patterns = output_bc_patterns.literal_value - output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode())) - - if not isinstance(output_dtypes, types.Literal): - raise TypeError("output_dtypes must be literal.") - output_dtypes = output_dtypes.literal_value - output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode())) - - if not isinstance(inplace_pattern, types.Literal): - raise TypeError("inplace_pattern must be literal.") - inplace_pattern = inplace_pattern.literal_value - inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) - - if not 
isinstance(allow_core_scalar, types.Literal): - raise TypeError("allow_core_scalar must be literal.") - allow_core_scalar = allow_core_scalar.literal_value - + """Compute core input/output types and return type for vectorized intrinsics.""" batch_ndim = len(input_bc_patterns[0]) - nin = len(constant_inputs_types) + len(input_types) - nout = len(output_bc_patterns) - if nin == 0: - raise TypingError("Empty argument list to vectorized op.") - - if nout == 0: - raise TypingError("Empty list of outputs for vectorized op.") - - if not all(isinstance(input, types.Array) for input in input_types): + if not all(isinstance(t, types.Array) for t in input_types): raise TypingError("Vectorized inputs must be arrays.") - - if not all( - len(pattern) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns - ): + if not all(len(p) == batch_ndim for p in input_bc_patterns + output_bc_patterns): raise TypingError( "Vectorized broadcastable patterns must have the same length." ) @@ -149,12 +134,15 @@ def _vectorized( for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True): core_ndim = input_type.ndim - len(bc_pattern) if allow_core_scalar and core_ndim == 0: - core_input_type = input_type.dtype + core_input_types.append(input_type.dtype) else: - core_input_type = types.Array( - dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout + # FIXME: inheriting layout from the full array is wrong for F-order + # inputs with batch dims — the core slice won't be F-contiguous. 
+ core_input_types.append( + types.Array( + dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout + ) ) - core_input_types.append(core_input_type) core_out_types = [ types.Array(numba.from_dtype(np.dtype(dtype)), len(output_core_shape), "C") @@ -172,115 +160,54 @@ def _vectorized( ) ] - for output_idx, input_idx in inplace_pattern: - output_type = input_types[input_idx] + for output_idx, inp_idx in inplace_pattern: + output_type = input_types[inp_idx] core_out_types[output_idx] = types.Array( dtype=output_type.dtype, ndim=output_type.ndim - batch_ndim, - layout=input_type.layout, + layout=output_type.layout, ) out_types[output_idx] = output_type ret_type = types.Tuple(out_types) - if len(output_dtypes) == 1: ret_type = ret_type.types[0] - sig = ret_type(*arg_types) - # So we can access the constant values in codegen... - input_bc_patterns_val = input_bc_patterns - output_bc_patterns_val = output_bc_patterns - output_dtypes_val = output_dtypes - inplace_pattern_val = inplace_pattern - input_types = input_types - size_is_none = isinstance(size_type, NoneType) + return core_input_types, core_out_types, out_types, ret_type - def codegen( - ctx, - builder, - sig, - args, - ): - [_, _, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args - - constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) - inputs = cgutils.unpack_tuple(builder, inputs) - output_core_shapes = [ - cgutils.unpack_tuple(builder, shape) - for shape in cgutils.unpack_tuple(builder, output_core_shapes) - ] - size = None if size_is_none else cgutils.unpack_tuple(builder, size) - inputs = [ - arrayobj.make_array(ty)(ctx, builder, val) - for ty, val in zip(input_types, inputs, strict=True) - ] - in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] - - iter_shape = compute_itershape( - ctx, - builder, - in_shapes, - input_bc_patterns_val, - size, - ) +def _codegen_return_outputs( + ctx, builder, sig, outputs, inplace_pattern, extra_incref=frozenset() 
+): + """Generate LLVM IR to return output arrays, handling incref for inplace. - outputs, output_types = make_outputs( - ctx, - builder, - iter_shape, - output_bc_patterns_val, - output_dtypes_val, - inplace_pattern_val, - inputs, - input_types, - output_core_shapes, - ) + Parameters + ---------- + extra_incref : frozenset + Additional output indices (e.g. indexed-update outputs) that alias an input + buffer and need an incref before returning, beyond inplace outputs. + """ + incref_set = set(dict(inplace_pattern).keys()) | set(extra_incref) - core_signature = typingctx.resolve_function_type( - core_func, - [ - *constant_inputs_types, - *core_input_types, - *core_out_types, - ], - {}, - ) + if len(outputs) == 1: + if incref_set: + ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue()) + return outputs[0]._getvalue() - make_loop_call( - typingctx, - ctx, + for idx in sorted(incref_set): + ctx.nrt.incref( builder, - core_func, - core_signature, - iter_shape, - constant_inputs, - inputs, - outputs, - input_bc_patterns_val, - output_bc_patterns_val, - input_types, - output_types, - core_scalar=allow_core_scalar, + sig.return_type.types[idx], + outputs[idx]._getvalue(), ) + return ctx.make_tuple( + builder, sig.return_type, [out._getvalue() for out in outputs] + ) - if len(outputs) == 1: - if inplace_pattern: - assert inplace_pattern[0][0] == 0 - ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue()) - return outputs[0]._getvalue() - for inplace_idx in dict(inplace_pattern): - ctx.nrt.incref( - builder, - sig.return_type.types[inplace_idx], - outputs[inplace_idx]._getvalue(), - ) - return ctx.make_tuple( - builder, sig.return_type, [out._getvalue() for out in outputs] - ) - - return sig, codegen +NO_INDEXED_INPUTS = encode_literals(()) +NO_INDEXED_OUTPUTS = encode_literals(()) +NO_SIZE = None def compute_itershape( @@ -385,7 +312,16 @@ def make_outputs( inputs: tuple[Any, ...], input_types: tuple[Any, ...], output_core_shapes: tuple, + 
update_outputs: dict | None = None, ) -> tuple[list[ir.Value], list[types.Array]]: + """Allocate output arrays for vectorized loop. + + Parameters + ---------- + update_outputs : dict, optional + Mapping ``{output_idx: (array, array_type)}`` for outputs that reuse + a scatter-target input buffer instead of being freshly allocated. + """ output_arrays = [] output_arry_types = [] one = ir.IntType(64)(1) @@ -393,6 +329,10 @@ def make_outputs( for i, (core_shape, bc, dtype) in enumerate( zip(output_core_shapes, out_bc, dtypes, strict=True) ): + if update_outputs is not None and i in update_outputs: + output_arrays.append(update_outputs[i][0]) + output_arry_types.append(update_outputs[i][1]) + continue if i in inplace_dict: output_arrays.append(inputs[inplace_dict[i]]) output_arry_types.append(input_types[inplace_dict[i]]) @@ -412,9 +352,9 @@ def make_outputs( array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) output_arrays.append(array) - # If there is no inplace operation, we know that all output arrays - # don't alias. Informing llvm can make it easier to vectorize. - if not inplace: + # If there is no inplace or scatter operation, we know that all output + # arrays don't alias. Informing llvm can make it easier to vectorize. + if not inplace and not update_outputs: # The first argument is the output pointer arg = builder.function.args[0] arg.add_attribute("noalias") @@ -436,6 +376,12 @@ def make_loop_call( input_types: tuple[Any, ...], output_types: tuple[Any, ...], core_scalar: bool = True, + input_read_spec: tuple[tuple[tuple[int, int], ...] | None, ...] | None = None, + idx_arrs: list | None = None, + idx_types: tuple | None = None, + idx_load_axes: tuple[tuple[int, ...], ...] | None = None, + idx_bc: tuple[tuple[bool, ...], ...] | None = None, + output_update_spec: tuple[tuple[tuple[int, int], ...] | None, ...] 
| None = None, ): safe = (False, False) @@ -453,6 +399,18 @@ def make_loop_call( zero = ir.Constant(ir.IntType(64), 0) + def _wrap_negative_index(idx_val, dim_size, signed): + """Wrap a negative index by adding the dimension size: idx + size if idx < 0. + + Only emits the branch for signed index dtypes; unsigned indices are + returned as-is since they cannot be negative. + """ + if not signed: + return idx_val + is_neg = builder.icmp_signed("<", idx_val, zero) + wrapped = builder.add(idx_val, dim_size) + return builder.select(is_neg, wrapped, idx_val) + # Setup loops and initialize accumulators for outputs # This part corresponds to opening the loops loop_stack = [] @@ -475,14 +433,93 @@ def make_loop_call( # Code in the inner most loop... idxs = [loopval.index for loopval in loops] + # Load indirect indices for all index arrays. + # Each index array may be ND (e.g. a 2D matrix index), accessed by + # multiple loop counters corresponding to its load axes. + indirect_idxs = [] + if idx_arrs is not None and idx_types is not None and idx_load_axes is not None: + for gi_k, (gi_arr, gi_type, load_axes) in enumerate( + zip(idx_arrs, idx_types, idx_load_axes) + ): + load_idxs = [] + for d, ax in enumerate(load_axes): + is_bc = idx_bc[gi_k][d] if idx_bc and len(idx_bc[gi_k]) > d else False + load_idxs.append(zero if is_bc else idxs[ax]) + gi_ptr = cgutils.get_item_pointer2( + context, + builder, + gi_arr.data, + cgutils.unpack_tuple(builder, gi_arr.shape), + cgutils.unpack_tuple(builder, gi_arr.strides), + gi_type.layout, + load_idxs, + False, + False, + ) + val = builder.load(gi_ptr) + # Extend to i64 to match stride types in get_item_pointer2. + i64 = ir.IntType(64) + if val.type != i64: + if gi_type.dtype.signed: + val = builder.sext(val, i64) + else: + val = builder.zext(val, i64) + # Negative indices are wrapped at point of use (see + # _wrap_negative_index) since the dimension size depends on the + # source/target array being indexed. 
+ indirect_idxs.append(val) + # Load values from input arrays input_vals = [] - for input, input_type, bc in zip(inputs, input_types, input_bc, strict=True): - core_ndim = input_type.ndim - len(bc) - - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [ - zero - ] * core_ndim + for input_i, (input, input_type, bc) in enumerate( + zip(inputs, input_types, input_bc, strict=True) + ): + spec = input_read_spec[input_i] if input_read_spec is not None else None + n_indexed = len(spec) if spec else 0 + # n_indexed source axes are replaced by the index arrays' broadcast + # loop dims. n_index_loop_dims = max ndim of the index arrays in the + # group (1 for 1D vectors, 2 for 2D matrices, etc.). + n_index_loop_dims = ( + max((idx_types[idx_k].ndim for idx_k, _ in spec), default=0) if spec else 0 # type: ignore[index] + ) + core_ndim = input_type.ndim - len(bc) - n_indexed + n_index_loop_dims + + if spec is not None: + assert idx_types is not None + indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} + input_shape = cgutils.unpack_tuple(builder, input.shape) + idxs_bc = [] + # Result dims correspond to: [indexed_loop_dim(s), then non-indexed source axes]. + # For 1D multi-index on consecutive axes 0..n-1: + # result dim 0 = indexed loop dim + # result dim 1 = source dim n, etc. + # For ND single-index on axis A (e.g. 
2D mat_idx): + # result dims A..A+D-1 = index loop dims + # remaining = non-indexed source axes + result_dim = 0 + for src_dim in range(input_type.ndim): + if src_dim in indexed_axes: + idx_k = indexed_axes[src_dim] + idx_val = _wrap_negative_index( + indirect_idxs[idx_k], + input_shape[src_dim], + signed=idx_types[idx_k].dtype.signed, + ) + idxs_bc.append(idx_val) + if n_indexed == 1: + result_dim += n_index_loop_dims + elif src_dim == max(indexed_axes): + result_dim += n_index_loop_dims + else: + if result_dim < len(bc): + idxs_bc.append(zero if bc[result_dim] else idxs[result_dim]) + else: + idxs_bc.append(zero) + result_dim += 1 + else: + idxs_bc = [ + zero if bc else idx for idx, bc in zip(idxs, bc, strict=True) + ] + [zero] * core_ndim ptr = cgutils.get_item_pointer2( context, builder, @@ -529,15 +566,51 @@ def make_loop_call( # Create output slices to pass to inner func output_slices = [] - for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True): + for output_i, (output, output_type, bc) in enumerate( + zip(outputs, output_types, output_bc, strict=True) + ): core_ndim = output_type.ndim - len(bc) size_type = output.shape.type.element # pyright: ignore[reportAttributeAccessIssue] output_shape = cgutils.unpack_tuple(builder, output.shape) # pyright: ignore[reportAttributeAccessIssue] output_strides = cgutils.unpack_tuple(builder, output.strides) # pyright: ignore[reportAttributeAccessIssue] - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [ - zero - ] * core_ndim + spec = output_update_spec[output_i] if output_update_spec is not None else None + if spec is not None: + assert idx_types is not None + # Indexed-update output: same logic as indexed-read input. + # Recompute core_ndim from the target's actual dims since + # output_bc may not match (it's the Elemwise output bc). 
+ indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} + n_indexed = len(indexed_axes) + n_index_loop_dims = max( + (idx_types[idx_k].ndim for idx_k, _ in spec), default=0 + ) + # Number of source (target) batch dims + source_batch_ndim = len(bc) + n_indexed - n_index_loop_dims + core_ndim = output_type.ndim - source_batch_ndim + max_indexed_axis = max(indexed_axes) + idxs_bc = [] + loop_dim = 0 + for src_dim in range(source_batch_ndim): + if src_dim in indexed_axes: + idx_k = indexed_axes[src_dim] + idx_val = _wrap_negative_index( + indirect_idxs[idx_k], + output_shape[src_dim], + signed=idx_types[idx_k].dtype.signed, + ) + idxs_bc.append(idx_val) + if src_dim >= max_indexed_axis: + loop_dim += n_index_loop_dims + else: + bc_dim = bc[loop_dim] if loop_dim < len(bc) else False + idxs_bc.append(zero if bc_dim else idxs[loop_dim]) + loop_dim += 1 + idxs_bc += [zero] * core_ndim + else: + idxs_bc = [ + zero if bc else idx for idx, bc in zip(idxs, bc, strict=True) + ] + [zero] * core_ndim ptr = cgutils.get_item_pointer2( context, builder, @@ -548,7 +621,6 @@ def make_loop_call( idxs_bc, *safe, ) - # Retrieve array item at index # This is a streamlined version of Numba's `GUArrayArg.load` core_arry_type = types.Array( @@ -583,3 +655,459 @@ def make_loop_call( loop.__exit__(None, None, None) return + + +@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) +def _vectorized( + typingctx, + core_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + allow_core_scalar, + constant_inputs_types, + outer_input_types, + output_core_shape_types, + size_type, + indexed_inputs, + indexed_outputs, +): + """Like _vectorized but with indirect indexing for reads and updates. + + Outer inputs are ordered as + ``[elemwise_inputs..., idx_0, idx_1, ..., update_target_0, ...]``. + + ``indexed_inputs`` groups elemwise input positions by which index they + read through: e.g. 
``((0, 2), (1,))`` means idx_0 reads inputs 0 and 2, + idx_1 reads input 1. Entries may be empty ``()`` for update-only indices. + + ``indexed_outputs`` has one entry per index array (same length as + ``indexed_inputs``). ``None`` means that index is not used for updates. + ``((out_0, out_1), mode)`` means that index updates outputs out_0 and + out_1 with *mode* ``"set"`` or ``"inc"``. + + Parameters + ---------- + outer_input_types : tuple of Array types + ``(elemwise_input_0, ..., elemwise_input_N, idx_0, ..., update_target_0, ...)`` + indexed_inputs : literal str + Encoded ``tuple[tuple[int, ...], ...]``. + indexed_outputs : literal str + Encoded ``tuple[tuple[tuple[int, ...], str] | None, ...]``. + """ + arg_types = [ + core_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + allow_core_scalar, + constant_inputs_types, + outer_input_types, + output_core_shape_types, + size_type, + indexed_inputs, + indexed_outputs, + ] + + input_bc_patterns = _decode_literal(input_bc_patterns, "input_bc_patterns") + output_bc_patterns = _decode_literal(output_bc_patterns, "output_bc_patterns") + output_dtypes = _decode_literal(output_dtypes, "output_dtypes") + inplace_pattern = _decode_literal(inplace_pattern, "inplace_pattern") + indexed_inputs = _decode_literal(indexed_inputs, "indexed_inputs") + indexed_outputs = _decode_literal(indexed_outputs, "indexed_outputs") + + if not isinstance(allow_core_scalar, types.Literal): + raise TypingError("allow_core_scalar must be literal.") + allow_core_scalar = allow_core_scalar.literal_value + + # Count scatter targets (one per unique output index) + _update_out_idxs = set() + for entry in indexed_outputs: + if entry is not None: + _update_out_idxs.update(entry[0]) + n_update_targets = len(_update_out_idxs) + + n_indices = len(indexed_inputs) + n_elemwise = len(outer_input_types) - n_indices - n_update_targets + source_input_types = tuple(outer_input_types[i] for i in range(n_elemwise)) + idx_types = 
tuple(outer_input_types[n_elemwise + k] for k in range(n_indices)) + update_target_types = tuple( + outer_input_types[n_elemwise + n_indices + j] for j in range(n_update_targets) + ) + + # indexed_inputs entries are (positions, axis) — one per index array. + # For multi-index (e.g. x[idx_row, idx_col]), an input appears in multiple + # entries with different axes. We aggregate per-input into a tuple of + # (idx_k, src_axis) pairs. + # + # idx_load_axes[k] = tuple of loop dims used to load index array k. + # For a 1-D index on axis A this is (A,); for a 2-D index it is (A, A+1), etc. + # For multi-index on consecutive axes, the group's min_axis is the start. + _read_spec_dict: dict[int, list[tuple[int, int]]] = {} + idx_load_axes = [] + idx_bc_list = [] # per index array: broadcastable tuple + for k, entry in enumerate(indexed_inputs): + positions, axis = entry[0], entry[1] + idx_bc = entry[2] if len(entry) > 2 else (False,) + idx_bc_list.append(idx_bc) + for p in positions: + _read_spec_dict.setdefault(p, []).append((k, axis)) + # Build write-side grouping: index arrays that update the same output + # share a group and should use the same min_axis. + _write_group: dict[int, list[tuple[int, int]]] = {} # out_idx -> [(k, axis)] + for k, entry in enumerate(indexed_outputs): + if entry is None: + continue + output_indices, _mode, axis = entry + for out_idx in output_indices: + _write_group.setdefault(out_idx, []).append((k, axis)) + + # idx_load_axes[k] = tuple of loop dims for loading index array k. 
+ for k, entry in enumerate(indexed_inputs): + _positions, axis = entry[0], entry[1] + idx_ndim = idx_types[k].ndim + # Find the group's min_axis from reads and writes + min_axis = axis + for p in _positions: + if p in _read_spec_dict: + for _other_k, other_axis in _read_spec_dict[p]: + min_axis = min(min_axis, other_axis) + for out_idx, group in _write_group.items(): + if any(gk == k for gk, _ in group): + for _other_k, other_axis in group: + min_axis = min(min_axis, other_axis) + idx_load_axes.append(tuple(range(min_axis, min_axis + idx_ndim))) + input_read_spec = tuple( + tuple(_read_spec_dict[p]) if p in _read_spec_dict else None + for p in range(n_elemwise) + ) + idx_load_axes = tuple(idx_load_axes) + + # Build effective input types that match input_bc_patterns ndim. + # For indexed inputs, the source ndim differs from the result ndim: + # - Multi-index (N 1-D indices on N axes): collapses N source axes into 1 loop dim + # - ND index (1 index with ndim D on 1 axis): expands 1 source axis into D loop dims + input_types = [] + for p, src_type in enumerate(source_input_types): + spec = input_read_spec[p] + if spec is not None: + n_indexed_axes = len(spec) + n_index_loop_dims = max(idx_types[idx_k].ndim for idx_k, _ in spec) + if n_indexed_axes != n_index_loop_dims: + effective_ndim = src_type.ndim - n_indexed_axes + n_index_loop_dims + input_types.append( + types.Array(src_type.dtype, effective_ndim, src_type.layout) + ) + else: + input_types.append(src_type) + else: + input_types.append(src_type) + input_types = tuple(input_types) + + # Per-output: tuple of (idx_k, axis) pairs, or None. + # Same format as input_read_spec. + # indexed_outputs entries are (positions, mode, axis) or None. 
+ _update_spec_dict: dict[int, list[tuple[int, int]]] = {} + update_out_to_target = {} + update_out_indices = set() + target_counter = 0 + for k, entry in enumerate(indexed_outputs): + if entry is None: + continue + output_indices, _mode, axis = entry + for out_idx in output_indices: + _update_spec_dict.setdefault(out_idx, []).append((k, axis)) + if out_idx not in update_out_to_target: + update_out_to_target[out_idx] = target_counter + target_counter += 1 + update_out_indices.add(out_idx) + output_update_spec = tuple( + tuple(_update_spec_dict[p]) if p in _update_spec_dict else None + for p in range(len(output_bc_patterns)) + ) + update_out_indices = frozenset(update_out_indices) + + core_input_types, core_out_types, out_types, ret_type = _compute_vectorized_types( + input_types, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + allow_core_scalar, + output_core_shape_types, + ) + + # Fix up output types for scattered outputs: they match the target buffer + if update_out_to_target: + core_out_types = list(core_out_types) + out_types = list(out_types) + batch_ndim = len(input_bc_patterns[0]) + for out_idx, target_idx in update_out_to_target.items(): + target_type = update_target_types[target_idx] + out_types[out_idx] = target_type + # Core ndim = target dims minus the dims addressed by the loop. + # For multi-index or ND, the indexed axes are replaced by loop dims + # via indirect indexing, so the "batch" portion of the target is + # the number of source dims addressed by indirect + loop counters. 
+ spec = output_update_spec[out_idx] + if spec is not None: + n_indexed = len(spec) + n_index_loop_dims = max(idx_types[idx_k].ndim for idx_k, _ in spec) + effective_batch = batch_ndim + n_indexed - n_index_loop_dims + else: + effective_batch = batch_ndim + core_out_types[out_idx] = types.Array( + dtype=target_type.dtype, + ndim=target_type.ndim - effective_batch, + layout=target_type.layout, + ) + out_types = tuple(out_types) + core_out_types = tuple(core_out_types) + + if len(out_types) == 1: + ret_type = out_types[0] + else: + ret_type = types.Tuple(out_types) + + sig = ret_type(*arg_types) + + size_is_none = isinstance(size_type, NoneType) + + # Save values for codegen closure + input_bc_patterns_val = input_bc_patterns + output_bc_patterns_val = output_bc_patterns + output_dtypes_val = output_dtypes + inplace_pattern_val = inplace_pattern + input_read_spec_val = input_read_spec + idx_types_val = idx_types + idx_load_axes_val = idx_load_axes + idx_bc_list_val = idx_bc_list + output_update_spec_val = output_update_spec + update_out_to_target_val = update_out_to_target + update_target_types_val = update_target_types + update_out_indices_val = update_out_indices + indexed_outputs_val = indexed_outputs + + def codegen(ctx, builder, sig, args): + [ + _, + _, + _, + _, + _, + _, + constant_inputs, + outer_inputs, + output_core_shapes, + size, + _, + _, + ] = args + + constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) + all_outer = cgutils.unpack_tuple(builder, outer_inputs) + output_core_shapes = [ + cgutils.unpack_tuple(builder, shape) + for shape in cgutils.unpack_tuple(builder, output_core_shapes) + ] + size = None if size_is_none else cgutils.unpack_tuple(builder, size) + + # First n_elemwise outer inputs are elemwise inputs (source arrays) + inputs = [ + arrayobj.make_array(source_input_types[i])(ctx, builder, all_outer[i]) + for i in range(n_elemwise) + ] + in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] + + # Next 
n_indices inputs are index arrays + idx_arrs = [ + arrayobj.make_array(idx_types_val[k])( + ctx, builder, all_outer[n_elemwise + k] + ) + for k in range(n_indices) + ] + + # Remaining inputs are scatter target buffers + update_target_arrs = [ + arrayobj.make_array(update_target_types_val[j])( + ctx, builder, all_outer[n_elemwise + n_indices + j] + ) + for j in range(n_update_targets) + ] + + # Build iter_shapes for compute_itershape. + # For indexed inputs, the source array may have more dims than the + # iteration shape (multi-index collapses multiple source axes into one + # loop dim). Replace the source shape with a constructed shape that + # matches the bc pattern: one entry per loop dim, with index lengths + # substituted for the indexed loop dim(s). + one = ir.IntType(64)(1) + iter_shapes = list(in_shapes) + iter_bc = list(input_bc_patterns_val) + idx_shapes = [ + cgutils.unpack_tuple(builder, idx_arrs[k].shape) for k in range(n_indices) + ] + for p, spec in enumerate(input_read_spec_val): + if spec is None: + continue + indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} + n_indexed = len(indexed_axes) + n_index_loop_dims = max(idx_types_val[idx_k].ndim for idx_k, _ in spec) + if n_indexed == n_index_loop_dims: + # Simple case (1 index on 1 axis, or N 1-D indices on N axes): + # substitute each indexed axis with the index's shape dim. + idx_k, axis = spec[0] + iter_shapes[p] = list(iter_shapes[p]) + iter_shapes[p][axis] = idx_shapes[idx_k][0] + else: + # Mismatch: ND index on fewer axes or multi-index collapsing. + # Build shape mapping result loop dims to source dims or index dims. + # Indexed source axes expand to n_index_loop_dims result dims. 
+ source_shape = cgutils.unpack_tuple(builder, inputs[p].shape) + batch_ndim = len(input_bc_patterns_val[p]) + indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} + max_axis = max(a for _, a in spec) + new_shape = [] + new_bc = [] + src_d = 0 + idx_d = 0 + for loop_d in range(batch_ndim): + if src_d in indexed_axes and idx_d < n_index_loop_dims: + # Placeholder — actual index shapes are contributed + # separately by each index array's iter_shape entry. + new_shape.append(one) + new_bc.append(True) + idx_d += 1 + if idx_d >= n_index_loop_dims: + src_d = max_axis + 1 + else: + new_shape.append(source_shape[src_d]) + new_bc.append(iter_bc[p][loop_d]) + src_d += 1 + iter_shapes[p] = new_shape + iter_bc[p] = tuple(new_bc) + + # Each index array participates in iter_shape validation. + # + # Write indices can broadcast against each other (e.g. ir=(3,1) + # and ic=(1,4) → (3,4)), so we honour their static bc. But if + # ALL write indices on a given loop dim are bc=True, none of them + # constrains the loop size and a shape-1 index would silently + # repeat writes. In that case we force bc=False so + # compute_itershape requires the index length to match the loop. + batch_ndim = len(input_bc_patterns_val[0]) if input_bc_patterns_val else 0 + + # Per loop dim: is every write index broadcastable? 
+ _write_all_bc = [True] * batch_ndim + for k in range(n_indices): + if indexed_outputs_val[k] is None: + continue + load_axes = idx_load_axes_val[k] + for d, ax in enumerate(load_axes): + if ax < batch_ndim: + idx_bc_on_d = ( + idx_bc_list_val[k][d] if d < len(idx_bc_list_val[k]) else False + ) + if not idx_bc_on_d: + _write_all_bc[ax] = False + + for k in range(n_indices): + load_axes = idx_load_axes_val[k] + is_write = indexed_outputs_val[k] is not None + idx_shape_entry = [one] * batch_ndim + bc_entry = [True] * batch_ndim + for d, ax in enumerate(load_axes): + if ax < batch_ndim and d < len(idx_shapes[k]): + idx_shape_entry[ax] = idx_shapes[k][d] + idx_bc_on_d = ( + idx_bc_list_val[k][d] if d < len(idx_bc_list_val[k]) else False + ) + # Force non-bc if this is a write index and all write + # indices on this dim are bc — otherwise the loop dim + # is unconstrained by any write index. + if is_write and idx_bc_on_d and _write_all_bc[ax]: + idx_bc_on_d = False + if ax < batch_ndim: + bc_entry[ax] = idx_bc_on_d + iter_shapes.append(idx_shape_entry) + iter_bc.append(tuple(bc_entry)) + + iter_shape = compute_itershape( + ctx, + builder, + iter_shapes, + tuple(iter_bc), + size, + ) + + # Build update_outputs dict for make_outputs: out_idx -> (array, type) + update_outputs_dict = ( + { + out_idx: ( + update_target_arrs[target_idx], + update_target_types_val[target_idx], + ) + for out_idx, target_idx in update_out_to_target_val.items() + } + if update_out_to_target_val + else None + ) + + outputs, output_types = make_outputs( + ctx, + builder, + iter_shape, + output_bc_patterns_val, + output_dtypes_val, + inplace_pattern_val, + inputs, + source_input_types, + output_core_shapes, + update_outputs=update_outputs_dict, + ) + + core_signature = typingctx.resolve_function_type( + core_func, + [ + *constant_inputs_types, + *core_input_types, + *core_out_types, + ], + {}, + ) + + make_loop_call( + typingctx, + ctx, + builder, + core_func, + core_signature, + iter_shape, + 
constant_inputs, + inputs, + outputs, + input_bc_patterns_val, + output_bc_patterns_val, + source_input_types, + output_types, + core_scalar=allow_core_scalar, + input_read_spec=input_read_spec_val, + idx_arrs=idx_arrs, + idx_types=idx_types_val, + idx_load_axes=idx_load_axes_val, + idx_bc=idx_bc_list_val, + output_update_spec=output_update_spec_val, + ) + + return _codegen_return_outputs( + ctx, + builder, + sig, + outputs, + inplace_pattern, + extra_incref=update_out_indices_val, + ) + + return sig, codegen diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index ddb7376fce..a17c5e70ce 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -381,14 +381,22 @@ def __setstate__(self, d): self.nfunc = None self.inplace_pattern = frozendict(self.inplace_pattern) + def make_scalar_node(self, *inputs): + """Create a scalar Apply node matching the dtypes of tensor inputs. + + Used by get_output_info, grad, and backend dispatchers to obtain + the scalar-level graph corresponding to this Elemwise operation. + """ + return self.scalar_op.make_node( + *[get_scalar_type(dtype=i.type.dtype).make_variable() for i in inputs] + ) + def get_output_info(self, *inputs): """Return the outputs dtype and broadcastable pattern and the dimshuffled inputs. 
""" - shadow = self.scalar_op.make_node( - *[get_scalar_type(dtype=i.type.dtype).make_variable() for i in inputs] - ) + shadow = self.make_scalar_node(*inputs) target_length = max(input.type.ndim for input in inputs) @@ -545,9 +553,7 @@ def as_scalar(t): scalar_inputs = list(map(as_scalar, inputs)) scalar_ograds = list(map(as_scalar, ograds)) - scalar_outputs = self.scalar_op.make_node( - *[get_scalar_type(dtype=i.type.dtype).make_variable() for i in inputs] - ).outputs + scalar_outputs = self.make_scalar_node(*inputs).outputs scalar_igrads = self.scalar_op.L_op( scalar_inputs, scalar_outputs, scalar_ograds ) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index b404dba4ee..59e84c56ac 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,6 +2,7 @@ import itertools import operator import sys +from collections import defaultdict from collections.abc import Generator, Sequence from functools import cache, reduce from heapq import heapify, heappop, heappush @@ -74,6 +75,91 @@ def create_inplace_node( ) -> Apply: pass + def _get_protected_inputs(self, fgraph): + """Collect inputs protected from in-place destruction.""" + protected = set( + itertools.chain.from_iterable( + f.protected for f in fgraph._features if isinstance(f, Supervisor) + ) + ) + protected.update(fgraph.outputs) + return protected + + def try_inplace_on_node( + self, + fgraph, + node, + candidate_pairs=None, + reason="inplace_optimizer", + extra_protected_inputs=frozenset(), + ): + """Try to make a single node operate in-place. + + First attempts all candidate pairs at once, then falls back to + one-at-a-time. Returns the (possibly replaced) node. + + Parameters + ---------- + candidate_pairs + Pre-filtered and sorted list of ``((out_idx, out), (in_idx, inp))`` + pairs. If None, computed from ``filter_candidate_pairs`` using + the standard protections plus ``extra_protected_inputs``. 
+ extra_protected_inputs + Additional variables that must not be used as in-place targets, + only used when ``candidate_pairs`` is None. + """ + if candidate_pairs is None: + protected_inputs = self._get_protected_inputs(fgraph) + protected_inputs.update(extra_protected_inputs) + candidate_pairs = self.filter_candidate_pairs( + fgraph, node, protected_inputs + ) + if not candidate_pairs: + return node + + # Try in-placing all outputs at once + tried_inputs = set() + inplace_pattern = {} + for (o, _), (i, _) in candidate_pairs: + if o not in inplace_pattern and i not in tried_inputs: + inplace_pattern[o] = [i] + tried_inputs.add(i) + + inplace_node = self.create_inplace_node(node, inplace_pattern) + if inplace_node.op.destroy_map == inplace_pattern: + replacements = tuple(zip(node.outputs, inplace_node.outputs)) + try: + fgraph.replace_all_validate(replacements, reason=reason) + except InconsistencyError: + pass + else: + copy_stack_trace(node.outputs, inplace_node.outputs) + return inplace_node + + # Fall back to one output/input at a time + tried_inputs = set() + inplace_pattern = {} + original_node = node + for (o, _), (i, _) in candidate_pairs: + if o not in inplace_pattern and i not in tried_inputs: + inplace_pattern[o] = [i] + tried_inputs.add(i) + + inplace_node = self.create_inplace_node(node, inplace_pattern) + if inplace_node.op.destroy_map != inplace_pattern: + # This Op can't respect this partial inplace pattern, + # We assume it can't support any other cases + break + replacements = tuple(zip(node.outputs, inplace_node.outputs)) + try: + fgraph.replace_all_validate(replacements, reason=reason) + node = inplace_node + except InconsistencyError: + inplace_pattern.pop(o) + if node is not original_node: + copy_stack_trace(original_node.outputs, node.outputs) + return node + def apply(self, fgraph): r""" @@ -127,12 +213,7 @@ def apply(self, fgraph): } large_graph = len(fgraph.apply_nodes) > 500 - protected_inputs = set( - itertools.chain.from_iterable( - 
f.protected for f in fgraph._features if isinstance(f, Supervisor) - ) - ) - protected_inputs.update(fgraph.outputs) + protected_inputs = self._get_protected_inputs(fgraph) root_destroyer = fgraph.destroy_handler.root_destroyer self_op = self.op @@ -163,14 +244,15 @@ def apply(self, fgraph): if not candidate_pairs: continue + # If the fgraph has updates, we try to prioritize in-placing + # on the pairs that correspond to the update sorted_candidate_pairs = candidate_pairs if op_updates and (node_updates := set(node.outputs) & set_op_updates): - # If the fgraph has updates, we try to prioritize in-placing on the pairs that correspond to the update direct_update_pairs = [] indirect_update_pairs = [] other_update_pairs = [] for pair in candidate_pairs: - ((o, out), (i, inp)) = pair + ((_o, out), (_i, inp)) = pair if out in node_updates: direct_update_inp = op_updates[out] if direct_update_inp is inp: @@ -190,54 +272,11 @@ def apply(self, fgraph): direct_update_pairs + indirect_update_pairs + other_update_pairs ) - # Try in-placing all outputs at once - tried_inputs = set() - inplace_pattern = {} - for (o, _), (i, _) in sorted_candidate_pairs: - if o not in inplace_pattern and i not in tried_inputs: - inplace_pattern[o] = [i] - tried_inputs.add(i) - - inplace_node = self.create_inplace_node(node, inplace_pattern) - if inplace_node.op.destroy_map == inplace_pattern: - replacements = tuple(zip(node.outputs, inplace_node.outputs)) - try: - fgraph.replace_all_validate(replacements, reason=reason) - except InconsistencyError: - prof["nb_eager_inconsistent"] += 1 - else: - prof["nb_replaced"] += 1 - copy_stack_trace(node.outputs, inplace_node.outputs) - continue - - # If it fails or doesn't match the desired inplace pattern, try one output/input at a time - tried_inputs = set() - inplace_pattern = {} - replaced = False - original_node = node - for (o, _), (i, _) in sorted_candidate_pairs: - if o not in inplace_pattern and i not in tried_inputs: - inplace_pattern[o] = [i] - 
tried_inputs.add(i) - - inplace_node = self.create_inplace_node(node, inplace_pattern) - if inplace_node.op.destroy_map != inplace_pattern: - # This Op can't respect this partial inplace pattern, - # We assume it can't support any other cases - break - else: - replacements = tuple(zip(node.outputs, inplace_node.outputs)) - try: - fgraph.replace_all_validate(replacements, reason=reason) - node = inplace_node - replaced = True - except InconsistencyError: - prof["nb_inconsistent"] += 1 - # The input, not the output caused inconsistencies - inplace_pattern.pop(o) - if replaced: - copy_stack_trace(original_node.outputs, node.outputs) - prof["nb_replaced"] += replaced + result = self.try_inplace_on_node( + fgraph, node, sorted_candidate_pairs, reason=reason + ) + if result is not node: + prof["nb_replaced"] += 1 return prof @@ -555,6 +594,26 @@ def apply(self, fgraph): callbacks_before = fgraph.execute_callbacks_times.copy() callback_before = fgraph.execute_callbacks_time + def _insert_sorted_subgraph( + sorted_subgraphs, all_subgraphs_bitset, ancestors_bitset, entry + ): + """Insert a subgraph entry into sorted_subgraphs respecting topological order.""" + if not (ancestors_bitset & all_subgraphs_bitset): + # No dependency on previous subgraphs, append at the end + sorted_subgraphs.append(entry) + else: + # Find the right insertion point by iterating in topological + # order (reverse of stored order), cumulatively excluding each + # subgraph until the dependency check passes. 
+ remaining_bitset = all_subgraphs_bitset + for index, (other_bitset, _) in enumerate(reversed(sorted_subgraphs)): + remaining_bitset &= ~other_bitset + if not (ancestors_bitset & remaining_bitset): + break + else: + raise RuntimeError("Failed to find insertion point for subgraph") + sorted_subgraphs.insert(-(index + 1), entry) + def find_fuseable_subgraphs( fg: FunctionGraph, ) -> Generator[tuple[tuple[Variable], tuple[Variable]], None, None]: @@ -832,43 +891,153 @@ def elemwise_scalar_op_has_c_code( if node_ancestors_bitset & subgraph_bitset ) - # Add new subgraph to sorted_subgraphs - # Because we start from sink nodes in reverse topological order, most times new subgraphs - # don't depend on previous subgraphs, so we can just append them at the end. - if not (unfuseable_ancestors_bitset & all_subgraphs_bitset): - # That's the case here - # None of the unfuseable_ancestors (i.e, the ancestors) are present in the previous collected subgraphs - sorted_subgraphs.append( - (subgraph_bitset, (subgraph_inputs, subgraph_outputs)) + _insert_sorted_subgraph( + sorted_subgraphs, + all_subgraphs_bitset, + unfuseable_ancestors_bitset, + (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), + ) + + all_subgraphs_bitset |= subgraph_bitset + + # Merge sibling groups: independent subgraphs or remaining + # candidate nodes that share inputs but have no producer-consumer + # edge between them. The eager expansion above only walks + # producer-consumer edges, so it misses siblings like f(x) and + # g(x) that share an input without one feeding into the other. + sibling_candidates: list = list(sorted_subgraphs) + for node, bf in nodes_bitflags.items(): + if node not in candidate_starting_nodes: + continue + if not (bf & all_subgraphs_bitset): + sibling_candidates.append( + (bf, (tuple(dict.fromkeys(node.inputs)), node.outputs)) ) - else: - # But not here, so we need to find the right position for insertion. 
- # We iterate through the previous subgraphs in topological order (reverse of the stored order). - # We cumulatively exclude each subgraph_bitset and perform the same dependency check again, until it passes. - remaining_subgraphs_bitset = all_subgraphs_bitset - for index, (other_subgraph_bitset, _) in enumerate( - reversed(sorted_subgraphs) - ): - # Exclude subgraph bitset - remaining_subgraphs_bitset &= ~other_subgraph_bitset - if not ( - unfuseable_ancestors_bitset & remaining_subgraphs_bitset + # Create a mapping from inputs to sibling groups that consume them. + # Skip scalar constants as they get inlined later and don't + # represent meaningful shared computation between siblings. + input_to_sibling_idxs: dict[Variable, list[int]] = defaultdict(list) + for sibling_idx, (_, (inputs, outputs)) in enumerate(sibling_candidates): + out_bcast = outputs[0].type.broadcastable + for inp in inputs: + if isinstance(inp, TensorConstant) and inp.unique_value is not None: + continue + # Only group siblings through inputs that aren't broadcasted + # As we may be dealing with sibiling of different shapes. 
+ # This is conservative, we could check if other shared inputs + # jointly imply equal shapes (or look at static shapes if available) + if inp.type.broadcastable != out_bcast: + continue + input_to_sibling_idxs[inp].append(sibling_idx) + + # Track which candidate each index was merged into (union-find) + merged_into: dict[int, int] = {} + + def find_canonical(idx): + """Follow merge chain to find the live candidate.""" + try: + while True: + idx = merged_into[idx] + except KeyError: + return idx + + for sibling_idxs in input_to_sibling_idxs.values(): + if len(sibling_idxs) < 2: + continue + + for i in range(len(sibling_idxs) - 1): + sibling_i = find_canonical(sibling_idxs[i]) + + bitset_i, (inputs_i, outputs_i) = sibling_candidates[sibling_i] + bcast_i = outputs_i[0].type.broadcastable + merged = False + for sibling_j in sibling_idxs[i + 1 :]: + sibling_j = find_canonical(sibling_j) + if sibling_j == sibling_i: + # Already in the same group + continue + bitset_j, (inputs_j, outputs_j) = sibling_candidates[sibling_j] + if bcast_i != outputs_j[0].type.broadcastable: + continue + + # Independence: neither is ancestor of the other + if ( + bitset_i & ancestors_bitsets[outputs_j[0].owner] + or bitset_j & ancestors_bitsets[outputs_i[0].owner] ): - break # bingo - else: # no-break - raise RuntimeError( - "Failed to find insertion point for fused subgraph" + continue + + # Merge sibling_j into sibling_i + merged_bitset = bitset_i | bitset_j + merged_outputs = (*outputs_i, *outputs_j) + merged_inputs = list(inputs_i) + merged_inputs_set = set(inputs_i) + for input_j in inputs_j: + if input_j not in merged_inputs_set: + merged_inputs.append(input_j) + merged_inputs_set.add(input_j) + merged_into[sibling_j] = sibling_i + + # Update ancestor bitsets so that any node + # depending on part of the merged group now + # depends on all of it (mirrors the main loop). 
+ merged_ancestors = reduce( + or_, + (ancestors_bitsets[o.owner] for o in merged_outputs), + ) + ancestors_bitsets |= ( + (n, n_anc | merged_ancestors) + for n, n_anc in ancestors_bitsets.items() + if n_anc & merged_bitset ) - sorted_subgraphs.insert( - -(index + 1), - (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), - ) - # Add subgraph to all_subgraphs_bitset - all_subgraphs_bitset |= subgraph_bitset + # Update locals for next iteration + bitset_i = merged_bitset + inputs_i = tuple(merged_inputs) + outputs_i = merged_outputs + merged = True + + if merged: + # Write back so other input lists see the merged state + sibling_candidates[sibling_i] = ( + bitset_i, + (inputs_i, outputs_i), + ) + + # Update sorted_subgraphs + merged_entry = (bitset_i, (inputs_i, outputs_i)) + + # Replace earliest existing sibling with merged graph; drop the rest + pos = next( + ( + sg_idx + for sg_idx, (sg_bitset, _) in enumerate( + sorted_subgraphs + ) + if sg_bitset & bitset_i + ), + None, + ) + if pos is not None: + sorted_subgraphs[:] = [ + *sorted_subgraphs[:pos], + merged_entry, + *( + e + for e in sorted_subgraphs[pos + 1 :] + if not (e[0] & bitset_i) + ), + ] + else: + # All siblings were singletons: need a sorted insertion + _insert_sorted_subgraph( + sorted_subgraphs, + all_subgraphs_bitset, + ancestors_bitsets[outputs_i[0].owner], + merged_entry, + ) + all_subgraphs_bitset |= bitset_i - # Finished exploring the whole graph - # Yield from sorted_subgraphs, discarding the subgraph_bitset yield from (io for _, io in sorted_subgraphs) max_operands = elemwise_max_operands_fct(None) diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py new file mode 100644 index 0000000000..febc461b0c --- /dev/null +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -0,0 +1,814 @@ +"""Fuse indexed reads and updates into Elemwise iteration loops. 
+ +Introduces ``IndexedElemwise``, an ``OpFromGraph`` that wraps +``AdvancedSubtensor1`` + ``Elemwise`` + ``AdvancedIncSubtensor1`` subgraphs +so the Numba backend can generate a single loop with indirect indexing, +eliminating materialised intermediate arrays. +""" + +from pytensor.compile import optdb +from pytensor.compile.builders import OpFromGraph +from pytensor.graph import node_rewriter +from pytensor.graph.rewriting.basic import GraphRewriter, dfs_rewriter +from pytensor.graph.rewriting.db import SequenceDB +from pytensor.printing import op_debug_information +from pytensor.scalar.basic import Composite, identity +from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer +from pytensor.tensor.shape import Reshape, shape_padright +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + AdvancedSubtensor, + AdvancedSubtensor1, + indices_from_subtensor, +) + + +@node_rewriter([DimShuffle]) +def undo_take_dimshuffle_for_fusion(fgraph, node): + """Undo ``DimShuffle(AdvancedSubtensor1(DimShuffle(x), idx))`` -> ``AdvancedSubtensor(x, :, ..., idx, :, ...)``. + + The ``local_replace_AdvancedSubtensor`` specialize rewrite converts + ``x[:, idx]`` into ``x.T[idx].T`` (axis-swap + AdvancedSubtensor1 + + axis-swap). This rewrite undoes that when the result feeds a single + Elemwise, so ``FuseIndexedElemwise`` can absorb the indexing directly + on the correct axis. + + See also ``undo_take_reshape_for_fusion`` which handles the analogous + Reshape+flatten pattern for ND indices. 
+ """ + # Outer DimShuffle must be an axis swap + outer_ds = node.op + if outer_ds.augment or outer_ds.drop: + return None + order = outer_ds.new_order + ndim = len(order) + if ndim < 2: + return None + + # Find the swapped axis: exactly two positions differ from identity + swapped = [i for i in range(ndim) if order[i] != i] + if len(swapped) != 2: + return None + ax_a, ax_b = swapped + if order[ax_a] != ax_b or order[ax_b] != ax_a: + return None + axis = max(ax_a, ax_b) # the non-zero axis (0 was swapped to axis) + + # Inner must be AdvancedSubtensor1 + inner = node.inputs[0] + if inner.owner is None or not isinstance(inner.owner.op, AdvancedSubtensor1): + return None + asub1_node = inner.owner + + # AdvancedSubtensor1's input must be the inverse DimShuffle. + # For a pair swap, the inverse is the same permutation (swap is self-inverse). + # A general permutation would need argsort(order) here. + inner_ds_var = asub1_node.inputs[0] + if inner_ds_var.owner is None or not isinstance(inner_ds_var.owner.op, DimShuffle): + return None + inner_ds = inner_ds_var.owner.op + if inner_ds.new_order != tuple(order): + return None + + # Both intermediates must be single-client + if len(fgraph.clients[inner]) != 1: + return None + if len(fgraph.clients[inner_ds_var]) != 1: + return None + + # Outer DimShuffle must be consumed only by a single Elemwise + clients = fgraph.clients[node.outputs[0]] + if len(clients) != 1: + return None + client_node, _client_idx = clients[0] + if not isinstance(getattr(client_node, "op", None), Elemwise): + return None + + # Build AdvancedSubtensor: x[:, ..., idx, :, ...] 
+ source = inner_ds_var.owner.inputs[0] + idx_var = asub1_node.inputs[1] + + idx_list = [slice(None)] * ndim + idx_list[axis] = 0 # pointer to the single index variable + new_out = AdvancedSubtensor(idx_list=idx_list)(source, idx_var) + return [new_out] + + +@node_rewriter([Reshape]) +def undo_take_reshape_for_fusion(fgraph, node): + """Undo ``Reshape(AdvancedSubtensor1(x, flatten(idx)), shape)`` for ND indices. + + ``transform_take`` rewrites ``x[mat_idx]`` (ND integer index) into + ``AdvancedSubtensor1(x, mat_idx.ravel()).reshape(mat_idx.shape + ...)``, + possibly with DimShuffle axis-swaps for non-zero axes. This rewrite + undoes that so ``FuseIndexedElemwise`` can absorb the ND index directly. + """ + [reshape_out] = node.outputs + + # Must feed a single Elemwise (or chain to one via another pre-fusion rewrite) + clients = fgraph.clients[reshape_out] + if len(clients) != 1: + return None + client_node, _ = clients[0] + if not isinstance(getattr(client_node, "op", None), Elemwise): + return None + + inner = node.inputs[0] + if inner.owner is None: + return None + + # --- Detect axis-0 pattern: Reshape(AdvancedSubtensor1(src, flatten(idx)), shape) + # --- Detect axis>0 pattern: Reshape(DimShuffle(AdvancedSubtensor1(DimShuffle(src), flatten(idx))), shape) + axis = 0 + asub1_node = None + + if isinstance(inner.owner.op, AdvancedSubtensor1): + asub1_node = inner.owner + elif isinstance(inner.owner.op, DimShuffle): + # Check for axis-swap DimShuffle wrapping AdvancedSubtensor1 + outer_ds = inner.owner.op + if outer_ds.augment or outer_ds.drop: + return None + order = outer_ds.new_order + ndim_ds = len(order) + if ndim_ds < 2: + return None + swapped = [i for i in range(ndim_ds) if order[i] != i] + if len(swapped) != 2: + return None + ax_a, ax_b = swapped + if order[ax_a] != ax_b or order[ax_b] != ax_a: + return None + axis = max(ax_a, ax_b) + + ds_inner = inner.owner.inputs[0] + if ds_inner.owner is None or not isinstance( + ds_inner.owner.op, AdvancedSubtensor1 + 
): + return None + asub1_node = ds_inner.owner + + # AdvancedSubtensor1's source must be the inverse DimShuffle + src_var = asub1_node.inputs[0] + if src_var.owner is None or not isinstance(src_var.owner.op, DimShuffle): + return None + if src_var.owner.op.new_order != tuple(order): + return None + + # Intermediates must be single-client + if len(fgraph.clients[ds_inner]) != 1: + return None + if len(fgraph.clients[src_var]) != 1: + return None + else: + return None + + if asub1_node is None: + return None + + # The index input to AdvancedSubtensor1 must be Reshape{1}(mat_idx, [-1]) (flatten) + flat_idx = asub1_node.inputs[1] + if flat_idx.owner is None or not isinstance(flat_idx.owner.op, Reshape): + return None + if flat_idx.owner.op.ndim != 1: + return None + mat_idx = flat_idx.owner.inputs[0] + if mat_idx.ndim < 2: + return None + + # AdvancedSubtensor1 output must be single-client + if len(fgraph.clients[asub1_node.outputs[0]]) != 1: + return None + + # Recover the original source (unwrap inner DimShuffle if axis > 0) + if axis > 0: + source = asub1_node.inputs[0].owner.inputs[0] + else: + source = asub1_node.inputs[0] + + # Build AdvancedSubtensor: source[:, ..., mat_idx, :, ...] + src_ndim = source.type.ndim + idx_list = [slice(None)] * src_ndim + idx_list[axis] = 0 # pointer to the single index variable + new_out = AdvancedSubtensor(idx_list=idx_list)(source, mat_idx) + return [new_out] + + +indexed_elemwise_optdb = SequenceDB() +optdb.register( + "fuse_indexed_into_elemwise", + indexed_elemwise_optdb, + "numba", + # After inplace_elemwise (position=50.5) so we see final inplace patterns, + # same position as other numba-specific rewrites (BlockwiseWithCoreShape). 
+ position=100, +) + +indexed_elemwise_optdb.register( + "undo_take_dimshuffle_for_fusion", + dfs_rewriter(undo_take_dimshuffle_for_fusion), + "numba", + position=0, +) + +indexed_elemwise_optdb.register( + "undo_take_reshape_for_fusion", + dfs_rewriter(undo_take_reshape_for_fusion), + "numba", + position=0.5, +) + + +class IndexedElemwise(OpFromGraph): + """Fuse indexed reads and updates into a single Elemwise iteration loop. + + Absorbs ``AdvancedSubtensor1`` (indexed reads on inputs) and + ``AdvancedIncSubtensor1`` (indexed updates on outputs) into one loop, + avoiding materialisation of intermediate arrays. + + Inner fgraph contains the unfused subgraph. + Non-Numba backends run it as-is via ``OpFromGraph.perform``. + The Numba backend generates a single loop with indirect indexing. + + Outer inputs are ordered as:: + + [elemwise_inputs..., idx_0, idx_1, ..., update_target_0, ...] + + Elemwise inputs whose values are read via an index have their source + arrays substituted in place. + + Parameters + ---------- + indexed_inputs : tuple of (tuple[int, ...], int, tuple[bool, ...]) + One entry per index array: ``(elemwise_input_positions, axis, idx_broadcastable)``. + indexed_outputs : tuple of ((tuple[int, ...], str, int) | None) + One entry per index array: ``(output_positions, mode, axis)`` or ``None``. 
+ """ + + def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): + self.indexed_inputs = indexed_inputs + self.indexed_outputs = indexed_outputs + super().__init__(*args, on_unused_input="ignore", accept_inplace=True, **kwargs) + + def __str__(self): + for node in self.fgraph.apply_nodes: + if isinstance(node.op, Elemwise): + return f"IndexedElemwise{{{node.op!s}}}" + return "IndexedElemwise" + + +@op_debug_information.register(IndexedElemwise) +def _op_debug_information_IndexedElemwise(op, node): + info = {} + + n_idx = len(op.indexed_inputs) + n_update_targets = sum(1 for e in op.indexed_outputs if e is not None) + n_elemwise = len(node.inputs) - n_idx - n_update_targets + + # Annotate indexed-read inputs + for k, (positions, _axis, _bc) in enumerate(op.indexed_inputs): + idx_label = f"idx_{k}" + for pos in positions: + if pos < len(node.inputs): + info[node.inputs[pos]] = f"indexed read ({idx_label})" + + # Annotate index arrays (after elemwise inputs) + for k in range(n_idx): + idx_pos = n_elemwise + k + if idx_pos < len(node.inputs): + info[node.inputs[idx_pos]] = f"idx_{k}" + + # Annotate update targets and outputs + buf_counter = 0 + target_start = n_elemwise + n_idx + target_offset = 0 + for k, entry in enumerate(op.indexed_outputs): + if entry is None: + continue + out_positions, mode, _axis = entry + buf_label = f"buf_{buf_counter}" + buf_counter += 1 + idx_label = f"idx_{k}" + + target_pos = target_start + target_offset + target_offset += 1 + if target_pos < len(node.inputs): + info[node.inputs[target_pos]] = buf_label + + for out_idx in out_positions: + if out_idx < len(node.outputs): + info[node.outputs[out_idx]] = ( + f"indexed {mode} ({buf_label}, {idx_label})" + ) + + return {node: info} + + +class FuseIndexedElemwise(GraphRewriter): + """Fuse indexed reads and indexed updates into Elemwise loops. 
+ + Absorbs single-client ``AdvancedSubtensor1`` on inputs (indexed reads) + and single-client ``AdvancedIncSubtensor1`` on outputs (indexed updates) + into the Elemwise iteration, avoiding intermediate arrays. + + Supports multiple index arrays: e.g. ``x[idx_a] + y[idx_b]`` produces + two index groups. Index arrays are shared between reads and updates + when they refer to the same variable. + """ + + def apply(self, fgraph): + def _get_indexed_read_info(var): + """Extract indexed-read info from a variable. + + Returns ``(source, [(idx_var, axis), ...])`` or ``None``. + Handles: + - ``AdvancedSubtensor1(source, idx)`` -> single index on axis 0 + - ``AdvancedSubtensor(source, :, ..., idx, :, ...)`` -> single or + multi-index with tensor indices on consecutive axes, + followed only by full slices. + """ + if var.owner is None: + return None + op = var.owner.op + if isinstance(op, AdvancedSubtensor1): + return (var.owner.inputs[0], [(var.owner.inputs[1], 0)]) + if isinstance(op, AdvancedSubtensor): + indices = indices_from_subtensor(var.owner.inputs[1:], op.idx_list) + # Collect consecutive advanced (tensor) indices + adv = [] + for i, idx in enumerate(indices): + if idx == slice(None): + if adv: + break # trailing slices after the advanced group + elif hasattr(idx, "ndim") and idx.dtype != "bool": + if adv and adv[-1][1] != i - 1: + return None # non-consecutive + adv.append((idx, i)) + else: + return None # unsupported index type + if not adv: + return None + # Verify only full slices remain after the advanced group + for idx in indices[adv[-1][1] + 1 :]: + if idx != slice(None): + return None + return (var.owner.inputs[0], adv) + return None + + def find_indexed_input_groups(fgraph, node): + """Find single-client indexed-read inputs grouped by (index, axis). + + Returns ``[(idx_var, axis, (pos, ...))]`` -- one entry per distinct + ``(idx_var, axis)`` pair. A multi-index input contributes multiple + entries (one per indexed axis). 
+ """ + groups = {} # (idx_var, axis) -> (idx_var, axis, list of positions) + for i, inp in enumerate(node.inputs): + info = _get_indexed_read_info(inp) + if info is None: + continue + if any(c is not node for c, _ in fgraph.clients[inp]): + continue + _source, idx_axis_pairs = info + for idx_var, axis in idx_axis_pairs: + key = (idx_var, axis) + if key not in groups: + groups[key] = (idx_var, axis, []) + groups[key][2].append(i) + + return [(var, axis, tuple(pos)) for var, axis, pos in groups.values()] + + def _get_indexed_update_info(client_node): + """Extract indexed-update info from an AdvancedInc node. + + Returns ``(target, [(idx_var, axis), ...], mode)`` or ``None``. + """ + op = client_node.op + if isinstance(op, AdvancedIncSubtensor1): + target, _val, idx_var = client_node.inputs + mode = "set" if op.set_instead_of_inc else "inc" + return (target, [(idx_var, 0)], mode) + if isinstance(op, AdvancedIncSubtensor): + target = client_node.inputs[0] + _val = client_node.inputs[1] + index_vars = client_node.inputs[2:] + indices = indices_from_subtensor(index_vars, op.idx_list) + adv = [] + for i, idx in enumerate(indices): + if idx == slice(None): + if adv: + break + elif hasattr(idx, "ndim") and idx.dtype != "bool": + if adv and adv[-1][1] != i - 1: + return None + adv.append((idx, i)) + else: + return None + if not adv: + return None + for idx in indices[adv[-1][1] + 1 :]: + if idx != slice(None): + return None + mode = "set" if op.set_instead_of_inc else "inc" + return (target, adv, mode) + return None + + def find_indexed_update_consumers(fgraph, node): + """Find indexed-update consumers of Elemwise outputs. + + Returns ``{out_idx: (update_node, target, [(idx_var, axis), ...], mode)}``. + Only considers outputs that are the value input (position 1) of + the indexed update. 
+ """ + update_info = {} + for out_idx, out in enumerate(node.outputs): + clients = fgraph.clients[out] + inc_clients = [ + (c, ci) + for c, ci in clients + if ci == 1 + and isinstance(c.op, AdvancedIncSubtensor1 | AdvancedIncSubtensor) + ] + if len(inc_clients) != 1: + continue + [(client_node, _)] = inc_clients + info = _get_indexed_update_info(client_node) + if info is None: + continue + target, idx_axis_pairs, mode = info + # Don't fuse if val is broadcastable on any index loop dim — + # the Elemwise result is constant across the index and + # recomputing it per index position is wasteful. + val = client_node.inputs[1] + n_idx_dims = max((idx.ndim for idx, _ in idx_axis_pairs), default=0) + val_idx_bc = list(val.type.broadcastable)[:n_idx_dims] + if any(val_idx_bc): + continue + + # Check broadcast compatibility between val and target on + # non-indexed axes. Val may have fewer non-indexed dims + # than target — the extra target dims become core dims + # (broadcast output) — but ONLY when the indexed axes are + # rightmost in the target so that val's dims align with the + # index loop (numpy broadcasts from the right). When + # indexed axes are NOT rightmost, val's dims would align + # with trailing non-indexed axes instead, giving wrong + # results if fused. + indexed_axes = {a for _, a in idx_axis_pairs} + max_indexed_axis = max(indexed_axes) + has_trailing_non_indexed = any( + a > max_indexed_axis + for a in range(target.type.ndim) + if a not in indexed_axes + ) + non_indexed_target_bc = [ + bc + for i, bc in enumerate(target.type.broadcastable) + if i not in indexed_axes + ] + non_indexed_val_bc = list(val.type.broadcastable) + n_idx_dims = max((idx.ndim for idx, _ in idx_axis_pairs), default=0) + non_indexed_val_bc = non_indexed_val_bc[n_idx_dims:] + if len(non_indexed_val_bc) < len(non_indexed_target_bc): + # Val has fewer dims — broadcasting into extra target + # dims. Only safe when indexed axes are rightmost + # (no trailing non-indexed dims). 
+ # TODO: could also handle trailing non-indexed dims + # that are broadcastable (size-1) in the val — squeeze + # them away before fusing. Requires excluding + # local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1 + # to encounter these graphs naturally. + if has_trailing_non_indexed: + continue + elif any( + vbc and not tbc + for vbc, tbc in zip( + non_indexed_val_bc, non_indexed_target_bc, strict=False + ) + ): + continue + update_info[out_idx] = ( + client_node, + target, + idx_axis_pairs, + mode, + ) + return update_info + + for node in reversed(fgraph.toposort()): + if not isinstance(node.op, Elemwise): + continue + + read_groups = find_indexed_input_groups(fgraph, node) + update_consumers = find_indexed_update_consumers(fgraph, node) + + if not read_groups and not update_consumers: + continue + + indexed_positions = { + p for _, _ax, positions in read_groups for p in positions + } + + # If any inplace targets an indexed-read input, strip and re-run + # inplace with those inputs protected + if any( + inp_idx in indexed_positions + for inp_idx in node.op.inplace_pattern.values() + ): + stripped_node = Elemwise(node.op.scalar_op).make_node(*node.inputs) + fgraph.replace_all_validate( + list(zip(node.outputs, stripped_node.outputs)), + reason="fuse_indexed_elemwise_strip_inplace", + ) + + protected = frozenset( + stripped_node.inputs[i] for i in indexed_positions + ) + node = InplaceElemwiseOptimizer().try_inplace_on_node( + fgraph, + stripped_node, + reason="fuse_indexed_elemwise_reinplace", + extra_protected_inputs=protected, + ) + # Re-detect after inplace change + update_consumers = find_indexed_update_consumers(fgraph, node) + + # Merge read and update index arrays into a unified ordered list + all_idx_groups = {} # (idx_var, axis) -> (idx_var, position) + for idx_var, _ax, _ in read_groups: + key = (idx_var, _ax) + if key not in all_idx_groups: + all_idx_groups[key] = (idx_var, len(all_idx_groups)) + for _un, _target, idx_axis_pairs, _mode in 
update_consumers.values(): + for idx_var, axis in idx_axis_pairs: + key = (idx_var, axis) + if key not in all_idx_groups: + all_idx_groups[key] = (idx_var, len(all_idx_groups)) + + n_indices = len(all_idx_groups) + idx_vars = [None] * n_indices + for idx_var, pos in all_idx_groups.values(): + idx_vars[pos] = idx_var + + # Build destroy_map + outer_destroy_map = {} + for out_idx, inp_idx in node.op.inplace_pattern.items(): + if inp_idx not in indexed_positions and out_idx not in update_consumers: + outer_destroy_map[out_idx] = [inp_idx] + + # Inner fgraph inputs: + # [elemwise_inputs (sources substituted)..., idx_0, ..., target_0, ...] + inner_inputs = [ + inp.owner.inputs[0] if i in indexed_positions else inp + for i, inp in enumerate(node.inputs) + ] + inner_inputs = inner_inputs + idx_vars + + # If any scatter output also has other consumers, duplicate + # the elemwise output via Composite so the scatter can replace + # the duplicate while the original stays available. + multi_client_outs = set() + for out_idx in update_consumers: + update_node = update_consumers[out_idx][0] + if any( + c is not update_node + for c, _ in fgraph.clients[node.outputs[out_idx]] + ): + multi_client_outs.add(out_idx) + + if multi_client_outs: + # Rebuild the Elemwise with duplicated outputs. + # e.g. Exp(x) -> Composite([x], [exp(x), exp(x)])(x) + # so the Elemwise produces [out0, out1] where out1 is + # available for the scatter to replace. + scalar_op = node.op.scalar_op + if isinstance(scalar_op, Composite): + s_inputs = list(scalar_op.inputs) + s_outputs = list(scalar_op.outputs) + else: + scalar_node = scalar_op.make_node( + *[inp.type.to_scalar_type()() for inp in node.inputs] + ) + s_inputs = list(scalar_node.inputs) + s_outputs = list(scalar_node.outputs) + + # Map from original out_idx to the new duplicate out_idx. + # Wrap duplicates with identity so Composite._cleanup_graph + # doesn't clone the entire subgraph for repeated outputs. 
+ # TODO: _cleanup_graph should use identity instead of clone + # for duplicate outputs. + dup_map = {} + for out_idx in sorted(multi_client_outs): + dup_map[out_idx] = len(s_outputs) + s_outputs.append(identity(s_outputs[out_idx])) + + new_scalar_op = Composite(s_inputs, s_outputs) + new_elemwise = Elemwise(new_scalar_op)(*node.inputs, return_list=True) + + # Update node reference and remap outputs + old_node = node + node = new_elemwise[0].owner + + # Rebuild inner_inputs with the new node's inputs + inner_inputs = [ + inp.owner.inputs[0] if i in indexed_positions else inp + for i, inp in enumerate(node.inputs) + ] + inner_inputs = inner_inputs + idx_vars + else: + dup_map = {} + + # Inner fgraph outputs; add update targets + # For non-axis-0 indexed updates, transpose the target so the + # indexed axis moves to position 0 and core dims are trailing + # (matching the gufunc convention assumed by the codegen). + inner_outputs = list(node.outputs) + call_inputs = list(inner_inputs) + scatter_transpose_back = {} # scatter_idx -> inverse perm + # Save original idx_axis_pairs before transpose may remap axes, + # so we can look up the correct key in all_idx_groups later. + original_idx_axis_pairs = { + out_idx: list(entry[2]) for out_idx, entry in update_consumers.items() + } + for out_idx in sorted(update_consumers.keys()): + update_node, target, idx_axis_pairs, _mode = update_consumers[out_idx] + + # If there are batch dims on the target (beyond the elemwise loop), + # we transpose them to the right, so they work as "core_ndim" + # in the vectorized codegen (which must always be on the right) + # (e.g. 
target(5,3)[:,idx].inc(exp(vector)) -> target(3, 5)[idx, :].inc(exp(vector)), + idx_axes_set = {a for _, a in idx_axis_pairs} + max_idx_axis = max(idx_axes_set) + n_indexed_axes = len(idx_axes_set) + n_idx_dims = max((idx.ndim for idx, _ in idx_axis_pairs), default=0) + elemwise_batch_ndim = len(node.outputs[0].type.broadcastable) + source_batch = elemwise_batch_ndim + n_indexed_axes - n_idx_dims + needs_transpose = max_idx_axis >= source_batch + if needs_transpose: + # Build permutation: indexed axes first, then rest + idx_axes_sorted = sorted({a for _, a in idx_axis_pairs}) + non_idx_axes = [ + a for a in range(target.type.ndim) if a not in idx_axes_sorted + ] + perm = idx_axes_sorted + non_idx_axes + inv_perm = [0] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + target = target.dimshuffle(perm) + # Remap idx_axis_pairs to the transposed axes + axis_remap = {old: new for new, old in enumerate(perm)} + idx_axis_pairs = [ + (idx_var, axis_remap[axis]) for idx_var, axis in idx_axis_pairs + ] + # Update the consumers entry with remapped axes + update_consumers[out_idx] = ( + update_node, + target, + idx_axis_pairs, + _mode, + ) + + inner_inputs.append(target) + + target_pos = len(call_inputs) + if update_node.op.inplace and not needs_transpose: + call_inputs.append(target) + else: + call_inputs.append(target.copy()) + + # Use the duplicate output for scatter if multi-client + scatter_idx = dup_map.get(out_idx, out_idx) + # The value to scatter: use the (possibly duplicated) Elemwise output + scatter_value = node.outputs[scatter_idx] + + # To work with IndexedElemwise codegen, we moved indexed update batch dims to the right. + # target(5,3)[:,idx].inc(exp(vector)) -> target(3, 5)[idx, :].inc(exp(vector)) + # For this operation to be strictly valid, we need to add expand dims the scattered value. 
+ # target(5,3)[:,idx].inc(exp(vector)) -> target(3, 5)[idx, :].inc(exp(vector)[:, None]) + # This expand_dims is subsumed by the IndexedElemwise, but makes the inner graph valid. + core_ndim = target.type.ndim - source_batch + if needs_transpose and core_ndim > 0: + scatter_value = shape_padright(scatter_value, core_ndim) + + # Build the scatter output + # After transpose, indexed axes are at position 0+, + # so always use AdvancedIncSubtensor1 for single axis-0. + if len(idx_axis_pairs) == 1 and idx_axis_pairs[0][1] == 0: + inplace_op = AdvancedIncSubtensor1( + inplace=True, + set_instead_of_inc=update_node.op.set_instead_of_inc, + ) + scatter_out = inplace_op( + target, + scatter_value, + idx_axis_pairs[0][0], + ) + elif update_node.op.inplace and not needs_transpose: + if isinstance(update_node.op, AdvancedIncSubtensor1): + scatter_out = update_node.op( + target, scatter_value, update_node.inputs[2] + ) + else: + scatter_out = update_node.op( + target, scatter_value, *update_node.inputs[2:] + ) + else: + inplace_op = AdvancedIncSubtensor( + idx_list=update_node.op.idx_list, + inplace=True, + set_instead_of_inc=update_node.op.set_instead_of_inc, + ) + scatter_out = inplace_op( + target, scatter_value, *update_node.inputs[2:] + ) + inner_outputs[scatter_idx] = scatter_out + outer_destroy_map[scatter_idx] = [target_pos] + + if needs_transpose: + scatter_transpose_back[scatter_idx] = inv_perm + + # Build indexed_inputs spec for the Op + indexed_inputs_spec = [None] * n_indices + for idx_var, _ax, positions in read_groups: + key = (idx_var, _ax) + _var, idx_pos = all_idx_groups[key] + indexed_inputs_spec[idx_pos] = ( + positions, + _ax, + idx_var.type.broadcastable, + ) + # Fill entries for index arrays with no reads (write-only) + for k in range(n_indices): + if indexed_inputs_spec[k] is None: + for (iv, ax), (v, p) in all_idx_groups.items(): + if p == k: + indexed_inputs_spec[k] = ((), ax, iv.type.broadcastable) + break + + # Build indexed_outputs spec for the Op. 
+ # Use the (possibly remapped) axis from update_consumers for the + # spec value, but the ORIGINAL axis for the all_idx_groups lookup + # (since all_idx_groups was built before any transpose). + indexed_outputs_spec = [None] * n_indices + for out_idx in sorted(update_consumers.keys()): + _update_node, _target, idx_axis_pairs, mode = update_consumers[out_idx] + orig_pairs = original_idx_axis_pairs[out_idx] + for (idx_var, axis), (_orig_var, orig_axis) in zip( + idx_axis_pairs, orig_pairs + ): + key = (idx_var, orig_axis) + idx_pos = all_idx_groups.get(key, (None, None))[1] + if idx_pos is None: + continue + scatter_idx = dup_map.get(out_idx, out_idx) + if indexed_outputs_spec[idx_pos] is None: + indexed_outputs_spec[idx_pos] = ([scatter_idx], mode, axis) + else: + indexed_outputs_spec[idx_pos][0].append(scatter_idx) + indexed_outputs_spec = tuple( + (tuple(e[0]), e[1], e[2]) if e is not None else None + for e in indexed_outputs_spec + ) + + new_outs = IndexedElemwise( + inner_inputs, + inner_outputs, + destroy_map=outer_destroy_map, + indexed_inputs=tuple(indexed_inputs_spec), + indexed_outputs=indexed_outputs_spec, + )(*call_inputs, return_list=True) + + # Transpose back any scatter outputs that were transposed + for scatter_idx, inv_perm in scatter_transpose_back.items(): + new_outs[scatter_idx] = new_outs[scatter_idx].dimshuffle(inv_perm) + + # The node whose outputs we need to replace in the outer graph + orig_node = old_node if multi_client_outs else node + replacements = [] + for out_idx in range(len(orig_node.outputs)): + if out_idx in update_consumers: + update_node = update_consumers[out_idx][0] + scatter_idx = dup_map.get(out_idx, out_idx) + replacements.append((update_node.outputs[0], new_outs[scatter_idx])) + if out_idx in dup_map: + # Multi-client: also replace the raw elemwise output + replacements.append( + (orig_node.outputs[out_idx], new_outs[out_idx]) + ) + else: + replacements.append((orig_node.outputs[out_idx], new_outs[out_idx])) + + 
fgraph.replace_all_validate( + replacements, + reason="fuse_indexed_into_elemwise", + ) + + +indexed_elemwise_optdb.register( + "fuse_indexed_elemwise", + FuseIndexedElemwise(), + "numba", + position=1, +) diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index 6bb9ed5bd9..88714e0b27 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -1,3 +1,5 @@ +# Import to trigger registration of IndexedElemwise rewrites +import pytensor.tensor.rewriting.indexed_elemwise # noqa: F401 from pytensor.compile import optdb from pytensor.graph import node_rewriter from pytensor.graph.rewriting.basic import dfs_rewriter diff --git a/tests/benchmarks/test_gather_fusion.py b/tests/benchmarks/test_gather_fusion.py new file mode 100644 index 0000000000..9c545d08bf --- /dev/null +++ b/tests/benchmarks/test_gather_fusion.py @@ -0,0 +1,152 @@ +"""Micro-benchmarks for Elemwise fusion with indexed reads and updates. + +Tests the benefit of fusing AdvancedSubtensor1 (indexed reads) and +AdvancedIncSubtensor1 (indexed updates) into Elemwise loops, avoiding +materialization of intermediate arrays. 
+""" + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor import config +from pytensor.compile.mode import get_mode +from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.subtensor import AdvancedIncSubtensor1, advanced_subtensor1 + + +@pytest.fixture( + params=[ + (85, 919, 2, 6), # radon-like: small + (1000, 100_000, 2, 4), # medium + ], + ids=["small-85x919", "medium-1Kx100K"], +) +def gather_benchmark_setup(request): + n_bins, n_data, n_gathered, n_direct = request.param + + rng = np.random.default_rng(42) + idx = rng.integers(n_bins, size=n_data).astype(np.int64) + idx.sort() + + sources = [pt.vector(f"src_{i}", shape=(n_bins,)) for i in range(n_gathered)] + directs = [pt.vector(f"dir_{i}", shape=(n_data,)) for i in range(n_direct)] + + terms = [advanced_subtensor1(s, idx) for s in sources] + directs + out = terms[0] + for t in terms[1:]: + out = out + t + + inputs = sources + directs + numba_mode = get_mode("NUMBA") + + fn_fused = pytensor.function(inputs, out, mode=numba_mode, trust_input=True) + fn_unfused = pytensor.function( + inputs, + out, + mode=numba_mode.excluding("fuse_indexed_into_elemwise"), + trust_input=True, + ) + + assert any( + isinstance(n.op, IndexedElemwise) for n in fn_fused.maker.fgraph.toposort() + ), "IndexedElemwise not found in fused graph" + assert not any( + isinstance(n.op, IndexedElemwise) for n in fn_unfused.maker.fgraph.toposort() + ), "IndexedElemwise found in unfused graph" + + rng = np.random.default_rng(1) + vals = [rng.normal(size=inp.type.shape).astype(config.floatX) for inp in inputs] + + np.testing.assert_allclose(fn_fused(*vals), fn_unfused(*vals), rtol=1e-10) + + return fn_fused, fn_unfused, vals + + +def test_gather_fusion_fused(gather_benchmark_setup, benchmark): + fn_fused, _, vals = gather_benchmark_setup + fn_fused(*vals) # warmup + benchmark(fn_fused, *vals) + + +def test_gather_fusion_unfused(gather_benchmark_setup, 
benchmark): + _, fn_unfused, vals = gather_benchmark_setup + fn_unfused(*vals) # warmup + benchmark(fn_unfused, *vals) + + +# --------------------------------------------------------------------------- +# Scatter output fusion benchmarks +# --------------------------------------------------------------------------- + + +@pytest.fixture( + params=[ + (85, 919, 2, 6, "inc"), # radon-like: small, inc mode + (85, 919, 2, 6, "set"), # radon-like: small, set mode + (1000, 100_000, 2, 4, "inc"), # medium, inc mode + ], + ids=["small-85x919-inc", "small-85x919-set", "medium-1Kx100K-inc"], +) +def scatter_benchmark_setup(request): + n_bins, n_data, n_gathered, n_direct, mode = request.param + + rng = np.random.default_rng(42) + idx = rng.integers(n_bins, size=n_data).astype(np.int64) + idx.sort() + + sources = [pt.vector(f"src_{i}", shape=(n_bins,)) for i in range(n_gathered)] + directs = [pt.vector(f"dir_{i}", shape=(n_data,)) for i in range(n_direct)] + target = pt.vector("target", shape=(n_bins,)) + + terms = [advanced_subtensor1(s, idx) for s in sources] + directs + elemwise_out = terms[0] + for t in terms[1:]: + elemwise_out = elemwise_out + t + + if mode == "inc": + out = pt.inc_subtensor(target[idx], elemwise_out) + else: + out = pt.set_subtensor(target[idx], elemwise_out) + + inputs = sources + directs + [target] + numba_mode = get_mode("NUMBA") + + fn_fused = pytensor.function(inputs, out, mode=numba_mode, trust_input=True) + fn_unfused = pytensor.function( + inputs, + out, + mode=numba_mode.excluding("fuse_indexed_into_elemwise"), + trust_input=True, + ) + + assert any( + isinstance(n.op, IndexedElemwise) for n in fn_fused.maker.fgraph.toposort() + ), "IndexedElemwise not found in fused graph" + assert not any( + isinstance(n.op, AdvancedIncSubtensor1) + for n in fn_fused.maker.fgraph.toposort() + ), "AdvancedIncSubtensor1 still present in fused graph" + assert not any( + isinstance(n.op, IndexedElemwise) for n in fn_unfused.maker.fgraph.toposort() + ), 
"IndexedElemwise found in unfused graph" + + rng = np.random.default_rng(1) + vals = [rng.normal(size=inp.type.shape).astype(config.floatX) for inp in inputs] + + np.testing.assert_allclose(fn_fused(*vals), fn_unfused(*vals), rtol=1e-10) + + return fn_fused, fn_unfused, vals + + +def test_scatter_fusion_fused(scatter_benchmark_setup, benchmark): + fn_fused, _, vals = scatter_benchmark_setup + fn_fused(*vals) # warmup + benchmark(fn_fused, *vals) + + +def test_scatter_fusion_unfused(scatter_benchmark_setup, benchmark): + _, fn_unfused, vals = scatter_benchmark_setup + fn_unfused(*vals) # warmup + benchmark(fn_unfused, *vals) diff --git a/tests/benchmarks/test_rewriting.py b/tests/benchmarks/test_rewriting.py index cea3c4e676..5b0b86ba6f 100644 --- a/tests/benchmarks/test_rewriting.py +++ b/tests/benchmarks/test_rewriting.py @@ -45,7 +45,7 @@ def _deep_small_kernels(n): "graph_fn, n, expected_n_repl", [ ("deep_small_kernels", 20, (20, 60)), - ("large_fuseable_graph", 25, (128, 876)), + ("large_fuseable_graph", 25, (55, 901)), ], ) def test_fusion_rewrite_benchmark(graph_fn, n, expected_n_repl, benchmark): diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index cb2c27fada..cdd079a2e1 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -491,6 +491,8 @@ def test_swap_SharedVariable_with_given(self): assert in1.value is in2.value def test_copy_delete_updates(self): + from contextlib import nullcontext + w = iscalar("w") x = fscalar("x") # SharedVariable for tests, one of them has update @@ -498,25 +500,35 @@ def test_copy_delete_updates(self): z = shared(value=2, name="z") out = x + y + z - # Test for different linkers - # for mode in ["FAST_RUN","FAST_COMPILE"]: - # second_time = False + inplace_warn = pytest.warns(UserWarning, match="will still be mutated") + for mode in ("FAST_RUN", "FAST_COMPILE"): + y.set_value(1) + z.set_value(2) + # FAST_RUN applies inplace rewrites, so z will be 
destroyed. + # A warning is issued because dropping the update doesn't prevent + # the shared variable storage from being mutated. + ctx = inplace_warn if mode == "FAST_RUN" else nullcontext() ori = function([x], out, mode=mode, updates={z: z * 2}) - cpy = ori.copy(delete_updates=True) - - assert cpy(1) == 4 - assert cpy(1) == 4 - assert cpy(1) == 4 + with ctx: + cpy = ori.copy(delete_updates=True) + if mode == "FAST_COMPILE": + assert cpy(1) == 4 + assert cpy(1) == 4 + assert cpy(1) == 4 # Test if unused implicit and explicit inputs from delete_updates # are ignored as intended. for mode in ("FAST_RUN", "FAST_COMPILE"): + ctx = inplace_warn if mode == "FAST_RUN" else nullcontext() + ori = function([x], x, mode=mode, updates={z: z * 2}) - cpy = ori.copy(delete_updates=True) + with ctx: + cpy = ori.copy(delete_updates=True) ori = function([x, w], x, mode=mode, updates={z: z + w}) - cpy = ori.copy(delete_updates=True) + with ctx: + cpy = ori.copy(delete_updates=True) def test_shared_state0(self): a = scalar() # the a is for 'anonymous' (un-named). 
diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py new file mode 100644 index 0000000000..049062d118 --- /dev/null +++ b/tests/link/numba/test_indexed_elemwise.py @@ -0,0 +1,710 @@ +"""Tests for IndexedElemwise fusion (indexed reads and updates in Elemwise loops).""" + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor import config +from pytensor.compile.mode import get_mode +from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor1, + AdvancedSubtensor, + advanced_subtensor1, +) + + +numba = pytest.importorskip("numba") + +NUMBA_MODE = get_mode("NUMBA") +NUMBA_NO_FUSION = NUMBA_MODE.excluding("fuse_indexed_into_elemwise") + + +def fused_and_unfused(inputs, output): + """Compile fused and unfused versions of a graph.""" + fn = pytensor.function(inputs, output, mode=NUMBA_MODE, trust_input=True) + fn_u = pytensor.function(inputs, output, mode=NUMBA_NO_FUSION, trust_input=True) + return fn, fn_u + + +def assert_fused(fn): + """Assert that the compiled graph contains an IndexedElemwise node.""" + assert any(isinstance(n.op, IndexedElemwise) for n in fn.maker.fgraph.toposort()), ( + "IndexedElemwise not found in fused graph" + ) + + +# ============================================================ +# Correctness tests — indexed reads +# ============================================================ + + +class TestIndexedReadFusion: + """Test indexed reads (AdvancedSubtensor1 / AdvancedSubtensor) fused into Elemwise.""" + + def test_single_index_axis0(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(919,)) + fn, fn_u = fused_and_unfused([x, y], advanced_subtensor1(x, idx) + y) + assert_fused(fn) + xv, yv = rng.normal(size=(85,)), rng.normal(size=(919,)) + np.testing.assert_allclose(fn(xv, yv), 
fn_u(xv, yv), rtol=1e-10) + + def test_single_index_axis1(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = pt.matrix("x", shape=(3, 85)) + y = pt.matrix("y", shape=(3, 919)) + fn, fn_u = fused_and_unfused([x, y], x[:, idx] + y) + assert_fused(fn) + xv, yv = rng.normal(size=(3, 85)), rng.normal(size=(3, 919)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + def test_multi_index_2d(self): + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + y = pt.vector("y", shape=(50,)) + fn, fn_u = fused_and_unfused([x, ir, ic, y], x[ir, ic] + y) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + rv, cv = rng.integers(100, size=50), rng.integers(200, size=50) + yv = rng.normal(size=(50,)) + np.testing.assert_allclose(fn(xv, rv, cv, yv), fn_u(xv, rv, cv, yv), rtol=1e-10) + + def test_multi_index_3d_trailing_dim(self): + rng = np.random.default_rng(42) + x3 = pt.tensor3("x3", shape=(100, 200, 5)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + z = pt.matrix("z", shape=(50, 5)) + fn, fn_u = fused_and_unfused([x3, ir, ic, z], x3[ir, ic] + z) + assert_fused(fn) + xv = rng.normal(size=(100, 200, 5)) + rv, cv = rng.integers(100, size=50), rng.integers(200, size=50) + zv = rng.normal(size=(50, 5)) + np.testing.assert_allclose(fn(xv, rv, cv, zv), fn_u(xv, rv, cv, zv), rtol=1e-10) + + def test_multiple_gathered_sources(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x1 = pt.vector("x1", shape=(85,)) + x2 = pt.vector("x2", shape=(85,)) + y = pt.vector("y", shape=(919,)) + fn, fn_u = fused_and_unfused( + [x1, x2, y], advanced_subtensor1(x1, idx) + advanced_subtensor1(x2, idx) + y + ) + assert_fused(fn) + xv1, xv2, yv = ( + rng.normal(size=(85,)), + rng.normal(size=(85,)), + 
rng.normal(size=(919,)), + ) + np.testing.assert_allclose(fn(xv1, xv2, yv), fn_u(xv1, xv2, yv), rtol=1e-10) + + def test_broadcast_index_axis0(self): + """Static shape=(1,) index on axis 0 broadcasts against larger direct input.""" + x = pt.vector("x", shape=(100,)) + y = pt.vector("y", shape=(50,)) + idx = np.array([5], dtype=np.int64) # shape (1,), broadcastable + out = advanced_subtensor1(x, idx) + y + fn, fn_u = fused_and_unfused([x, y], out) + assert_fused(fn) + xv = np.arange(100, dtype="float64") + yv = np.ones(50) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + def test_broadcast_index_axis1(self): + """Static shape=(1,) index on axis 1 broadcasts against larger direct input.""" + x = pt.matrix("x", shape=(3, 100)) + y = pt.matrix("y", shape=(3, 50)) + idx = np.array([5], dtype=np.int64) + out = x[:, idx] + y # x[:, idx] has shape (3, 1), broadcasts to (3, 50) + fn, fn_u = fused_and_unfused([x, y], out) + assert_fused(fn) + xv = np.arange(300.0).reshape(3, 100) + yv = np.ones((3, 50)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + def test_nd_index_axis0(self): + """2D matrix index on axis 0.""" + rng = np.random.default_rng(42) + x = pt.vector("x", shape=(100,)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) + y = pt.matrix("y", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, mat_idx, y], pt.exp(x[mat_idx]) + y) + assert_fused(fn) + xv = rng.normal(size=(100,)) + iv = rng.integers(100, size=(10, 5)).astype(np.int64) + yv = rng.normal(size=(10, 5)) + np.testing.assert_allclose(fn(xv, iv, yv), fn_u(xv, iv, yv), rtol=1e-10) + + def test_nd_index_axis1(self): + """2D matrix index on axis 1 (via undo_take_reshape_for_fusion).""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(3, 100)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, mat_idx], pt.exp(x[:, mat_idx])) + assert_fused(fn) + xv = rng.normal(size=(3, 100)) + iv = 
rng.integers(100, size=(10, 5)).astype(np.int64) + np.testing.assert_allclose(fn(xv, iv), fn_u(xv, iv), rtol=1e-10) + + def test_nd_index_with_trailing_dims(self): + """2D index on axis 0 with trailing source dims.""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 7)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, mat_idx], pt.exp(x[mat_idx])) + assert_fused(fn) + xv = rng.normal(size=(100, 7)) + iv = rng.integers(100, size=(10, 5)).astype(np.int64) + np.testing.assert_allclose(fn(xv, iv), fn_u(xv, iv), rtol=1e-10) + + def test_nd_index_broadcast(self): + """Broadcastable 2D index (shape (1, C)) broadcasts against direct input.""" + rng = np.random.default_rng(42) + x = pt.vector("x", shape=(100,)) + bc_idx = pt.matrix("bc_idx", dtype="int64", shape=(1, 5)) + y = pt.matrix("y", shape=(10, 5)) + fn, fn_u = fused_and_unfused([x, bc_idx, y], x[bc_idx] + y) + assert_fused(fn) + xv = rng.normal(size=(100,)) + bcv = rng.integers(100, size=(1, 5)).astype(np.int64) + yv = rng.normal(size=(10, 5)) + np.testing.assert_allclose(fn(xv, bcv, yv), fn_u(xv, bcv, yv), rtol=1e-10) + + def test_scalar_and_vector_index(self): + """x[scalar_idx, vector_idx] — 0-d index broadcasts with 1-d index.""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + y = pt.vector("y", shape=(50,)) + fn, fn_u = fused_and_unfused( + [x, scalar_idx, vec_idx, y], x[scalar_idx, vec_idx] + y + ) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + sv = np.array(rng.integers(100), dtype=np.int64) + vv = rng.integers(200, size=50).astype(np.int64) + yv = rng.normal(size=50) + np.testing.assert_allclose(fn(xv, sv, vv, yv), fn_u(xv, sv, vv, yv), rtol=1e-10) + + def test_vector_and_scalar_index(self): + """x[vector_idx, scalar_idx] — reversed order.""" + rng = np.random.default_rng(42) + x = 
pt.matrix("x", shape=(100, 200)) + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + y = pt.vector("y", shape=(50,)) + fn, fn_u = fused_and_unfused( + [x, vec_idx, scalar_idx, y], x[vec_idx, scalar_idx] + y + ) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + vv = rng.integers(100, size=50).astype(np.int64) + sv = np.array(rng.integers(200), dtype=np.int64) + yv = rng.normal(size=50) + np.testing.assert_allclose(fn(xv, vv, sv, yv), fn_u(xv, vv, sv, yv), rtol=1e-10) + + def test_scalar_and_vector_index_with_trailing_dim(self): + """x[scalar_idx, vector_idx] on a 3-d tensor with trailing dim.""" + rng = np.random.default_rng(42) + x = pt.tensor3("x", shape=(100, 200, 7)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + z = pt.matrix("z", shape=(50, 7)) + fn, fn_u = fused_and_unfused( + [x, scalar_idx, vec_idx, z], x[scalar_idx, vec_idx] + z + ) + assert_fused(fn) + xv = rng.normal(size=(100, 200, 7)) + sv = np.array(rng.integers(100), dtype=np.int64) + vv = rng.integers(200, size=50).astype(np.int64) + zv = rng.normal(size=(50, 7)) + np.testing.assert_allclose(fn(xv, sv, vv, zv), fn_u(xv, sv, vv, zv), rtol=1e-10) + + def test_all_scalar_indices(self): + """AdvancedSubtensor with all 0-d indices (degenerate case).""" + rng = np.random.default_rng(42) + x = pt.matrix("x", shape=(100, 200)) + si0 = pt.scalar("si0", dtype="int64") + si1 = pt.scalar("si1", dtype="int64") + y = pt.scalar("y", dtype="float64") + adv_sub = AdvancedSubtensor(idx_list=(0, 1))(x, si0, si1) + fn, fn_u = fused_and_unfused([x, si0, si1, y], adv_sub + y) + assert_fused(fn) + xv = rng.normal(size=(100, 200)) + s0 = np.array(rng.integers(100), dtype=np.int64) + s1 = np.array(rng.integers(200), dtype=np.int64) + yv = np.array(rng.normal()) + np.testing.assert_allclose(fn(xv, s0, s1, yv), fn_u(xv, s0, s1, yv), rtol=1e-10) + + def test_negative_indices(self): + 
"""Negative indices must be handled correctly (sign-extended, not zero-extended).""" + rng = np.random.default_rng(42) + # Use unknown static shape so negative indices can't be canonicalized away + x = pt.vector("x") + y = pt.vector("y") + idx = pt.vector("idx", dtype="int64") + fn, fn_u = fused_and_unfused([x, idx, y], x[idx] + y) + assert_fused(fn) + xv = rng.normal(size=100) + # Negative indices: -1 means last element, -2 second to last, etc. + idxv = np.array([-1, -2, -3, 0, 1], dtype=np.int64) + yv = rng.normal(size=5) + np.testing.assert_allclose(fn(xv, idxv, yv), fn_u(xv, idxv, yv), rtol=1e-10) + + +# ============================================================ +# Correctness tests — indexed updates +# ============================================================ + + +class TestIndexedUpdateFusion: + """Test indexed updates (AdvancedIncSubtensor1) fused into Elemwise.""" + + def test_no_fusion_when_idx_axes_outside_elemwise_loop(self): + """Don't fuse if the indexed axes are not within the Elemwise loop. + + Here the index is on axis 0 of target(5, 10), but the Elemwise + output (10,) corresponds to axis 1 (the non-indexed trailing axis). + The indexed axis doesn't overlap with the Elemwise computation, so + fusing would misalign which input dims map to which target dims. + """ + rng = np.random.default_rng(42) + idx = rng.integers(5, size=10).astype(np.int64) + target = pt.matrix("target", shape=(5, 10)) + x = pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(10,)) + elemwise_out = advanced_subtensor1(x, idx) + y # shape (10,) + out = pt.inc_subtensor(target[idx], elemwise_out) + fn, fn_u = fused_and_unfused([x, y, target], out) + # Write not fused — the Elemwise loop dim is the non-indexed axis, + # not the indexed axis. Read fusion may still create an + # IndexedElemwise, but the AdvancedIncSubtensor1 must remain outside. 
+ assert any( + isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() + ) + xv = rng.normal(size=(85,)) + yv = rng.normal(size=(10,)) + tv = rng.normal(size=(5, 10)) + np.testing.assert_allclose(fn(xv, yv, tv), fn_u(xv, yv, tv), rtol=1e-10) + + def test_no_fusion_when_val_broadcasts_along_index_dim(self): + """Don't fuse if val is broadcastable on the index loop dim. + + When val is constant across the index (e.g. shape (1,) with idx + of size 10), fusing would recompute the same Elemwise result at + every index position. Better to compute once and scatter. + """ + rng = np.random.default_rng(42) + idx = rng.integers(5, size=10).astype(np.int64) + target = pt.vector("target", shape=(5,)) + x = pt.tensor("x", shape=(1,)) + out = pt.inc_subtensor(target[idx], pt.exp(x)) + fn, fn_u = fused_and_unfused([x, target], out) + # Write not fused — val broadcasts along the index loop dim. + assert any( + isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() + ) + xv = rng.normal(size=(1,)) + tv = rng.normal(size=(5,)) + np.testing.assert_allclose(fn(xv, tv), fn_u(xv, tv), rtol=1e-10) + + def test_broadcast_val_into_non_indexed_dims(self): + """Fuse when Elemwise output broadcasts into target's non-indexed dims. + + target(5, 3)[:, idx] += exp(val) — the Elemwise loop covers the + index axis (axis 1), and the scalar result broadcasts into the + non-indexed leading axis (axis 0, core_ndim=1). + + Requires excluding local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1 + so we keep AdvancedIncSubtensor on the correct axis. 
+ """ + rng = np.random.default_rng(42) + idx = rng.integers(3, size=10).astype(np.int64) + target = pt.matrix("target", shape=(5, 3)) + val = pt.vector("val", shape=(10,)) + out = pt.inc_subtensor(target[:, idx], pt.exp(val)) + + mode = NUMBA_MODE.excluding( + "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1" + ) + mode_u = NUMBA_NO_FUSION.excluding( + "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1" + ) + fn = pytensor.function([val, target], out, mode=mode, trust_input=True) + fn_u = pytensor.function([val, target], out, mode=mode_u, trust_input=True) + assert_fused(fn) + valv = rng.normal(size=(10,)) + tv = rng.normal(size=(5, 3)) + np.testing.assert_allclose(fn(valv, tv), fn_u(valv, tv), rtol=1e-10) + + def test_inc_subtensor(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(919,)) + t = pt.vector("t", shape=(85,)) + out = pt.inc_subtensor(t[idx], advanced_subtensor1(x, idx) + y) + fn, fn_u = fused_and_unfused([x, y, t], out) + assert_fused(fn) + xv, yv, tv = ( + rng.normal(size=(85,)), + rng.normal(size=(919,)), + rng.normal(size=(85,)), + ) + np.testing.assert_allclose(fn(xv, yv, tv), fn_u(xv, yv, tv), rtol=1e-10) + + def test_set_subtensor(self): + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(919,)) + t = pt.vector("t", shape=(85,)) + out = pt.set_subtensor(t[idx], advanced_subtensor1(x, idx) + y) + fn, fn_u = fused_and_unfused([x, y, t], out) + assert_fused(fn) + xv, yv, tv = ( + rng.normal(size=(85,)), + rng.normal(size=(919,)), + rng.normal(size=(85,)), + ) + np.testing.assert_allclose(fn(xv, yv, tv), fn_u(xv, yv, tv), rtol=1e-10) + + def test_target_not_modified_when_non_inplace(self): + """Non-inplace scatter should not modify the original target.""" + rng = np.random.default_rng(42) + idx = rng.integers(85, size=919).astype(np.int64) + x = 
pt.vector("x", shape=(85,)) + y = pt.vector("y", shape=(919,)) + t = pt.vector("t", shape=(85,)) + out = pt.inc_subtensor(t[idx], advanced_subtensor1(x, idx) + y) + fn = pytensor.function([x, y, t], out, mode=NUMBA_MODE, trust_input=True) + xv, yv = rng.normal(size=(85,)), rng.normal(size=(919,)) + tv = rng.normal(size=(85,)) + tv_copy = tv.copy() + fn(xv, yv, tv) + np.testing.assert_array_equal(tv, tv_copy) + + def test_multi_index_inc_subtensor(self): + rng = np.random.default_rng(42) + target = pt.matrix("target", shape=(100, 200)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + x = pt.vector("x", shape=(50,)) + out = pt.inc_subtensor(target[ir, ic], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, ir, ic, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200)) + rv = rng.integers(100, size=50).astype(np.int64) + cv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=50) + np.testing.assert_allclose(fn(tv, rv, cv, xv), fn_u(tv, rv, cv, xv), rtol=1e-10) + + def test_multi_index_set_subtensor(self): + rng = np.random.default_rng(42) + target = pt.matrix("target", shape=(100, 200)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + x = pt.vector("x", shape=(50,)) + out = pt.set_subtensor(target[ir, ic], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, ir, ic, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200)) + rv = rng.integers(100, size=50).astype(np.int64) + cv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=50) + np.testing.assert_allclose(fn(tv, rv, cv, xv), fn_u(tv, rv, cv, xv), rtol=1e-10) + + def test_multi_index_write_with_trailing_dims(self): + rng = np.random.default_rng(42) + target = pt.tensor3("target", shape=(100, 200, 5)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + x = pt.matrix("x", shape=(50, 5)) + out = 
pt.inc_subtensor(target[ir, ic], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, ir, ic, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200, 5)) + rv = rng.integers(100, size=50).astype(np.int64) + cv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=(50, 5)) + np.testing.assert_allclose(fn(tv, rv, cv, xv), fn_u(tv, rv, cv, xv), rtol=1e-10) + + def test_combined_multi_index_read_write(self): + """Read and write share the same multi-index arrays.""" + rng = np.random.default_rng(42) + target = pt.matrix("target", shape=(100, 200)) + source = pt.matrix("source", shape=(100, 200)) + ir = pt.vector("ir", dtype="int64", shape=(50,)) + ic = pt.vector("ic", dtype="int64", shape=(50,)) + x = pt.vector("x", shape=(50,)) + out = pt.inc_subtensor(target[ir, ic], source[ir, ic] + x) + fn, fn_u = fused_and_unfused([target, source, ir, ic, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200)) + sv = rng.normal(size=(100, 200)) + rv = rng.integers(100, size=50).astype(np.int64) + cv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=50) + np.testing.assert_allclose( + fn(tv, sv, rv, cv, xv), fn_u(tv, sv, rv, cv, xv), rtol=1e-10 + ) + + def test_scalar_and_vector_index_inc(self): + """inc_subtensor with x[scalar_idx, vector_idx] — 0-d and 1-d indices.""" + rng = np.random.default_rng(42) + target = pt.matrix("target", shape=(100, 200)) + scalar_idx = pt.scalar("scalar_idx", dtype="int64") + vec_idx = pt.vector("vec_idx", dtype="int64", shape=(50,)) + x = pt.vector("x", shape=(50,)) + out = pt.inc_subtensor(target[scalar_idx, vec_idx], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, scalar_idx, vec_idx, x], out) + assert_fused(fn) + tv = rng.normal(size=(100, 200)) + sv = np.array(rng.integers(100), dtype=np.int64) + vv = rng.integers(200, size=50).astype(np.int64) + xv = rng.normal(size=50) + np.testing.assert_allclose(fn(tv, sv, vv, xv), fn_u(tv, sv, vv, xv), rtol=1e-10) + + +class TestShapeValidation: + """Test that 
mismatched index/input shapes raise runtime errors. + + All inputs use ``shape=(None,)`` so shapes are unknown at compile time. + The fused loop's ``compute_itershape`` must catch mismatches at runtime. + """ + + def test_mismatched_index_and_direct_input(self): + """Index length doesn't match direct input on the same loop dim.""" + x = pt.vector("x", shape=(None,)) + y = pt.vector("y", shape=(None,)) + idx = pt.vector("idx", dtype="int64", shape=(None,)) + out = x[idx] + y + fn = pytensor.function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Matching: idx=50, y=50 — should work + fn(np.zeros(100), np.zeros(50, dtype=np.int64), np.zeros(50)) + # Mismatched: idx=50, y=49 — should error + with pytest.raises(Exception): + fn(np.zeros(100), np.zeros(50, dtype=np.int64), np.zeros(49)) + + def test_mismatched_multi_index_lengths(self): + """Two index arrays in a multi-index have different lengths.""" + x = pt.matrix("x", shape=(None, None)) + ir = pt.vector("ir", dtype="int64", shape=(None,)) + ic = pt.vector("ic", dtype="int64", shape=(None,)) + y = pt.vector("y", shape=(None,)) + out = x[ir, ic] + y + fn = pytensor.function([x, ir, ic, y], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Matching: all 50 — should work + fn( + np.zeros((100, 200)), + np.zeros(50, dtype=np.int64), + np.zeros(50, dtype=np.int64), + np.zeros(50), + ) + # Mismatched: ir=50, ic=49 — should error + with pytest.raises(Exception): + fn( + np.zeros((100, 200)), + np.zeros(50, dtype=np.int64), + np.zeros(49, dtype=np.int64), + np.zeros(50), + ) + + def test_mismatched_index_vs_direct_on_non_indexed_axis(self): + """Index and direct input disagree on a non-indexed (trailing) axis.""" + x = pt.tensor3("x", shape=(None, None, None)) + ir = pt.vector("ir", dtype="int64", shape=(None,)) + ic = pt.vector("ic", dtype="int64", shape=(None,)) + z = pt.matrix("z", shape=(None, None)) + out = x[ir, ic] + z # result shape (N, trailing) + fn = pytensor.function([x, 
ir, ic, z], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Matching: trailing dim = 5 for both + fn( + np.zeros((10, 20, 5)), + np.zeros(3, dtype=np.int64), + np.zeros(3, dtype=np.int64), + np.zeros((3, 5)), + ) + # Mismatched: x trailing=5, z trailing=4 + with pytest.raises(Exception): + fn( + np.zeros((10, 20, 5)), + np.zeros(3, dtype=np.int64), + np.zeros(3, dtype=np.int64), + np.zeros((3, 4)), + ) + + def test_runtime_broadcast_on_index_dim(self): + """Symbolic shapes that happen to be 1 at runtime — broadcast check.""" + x = pt.vector("x", shape=(None,)) + y = pt.vector("y", shape=(None,)) + idx = pt.vector("idx", dtype="int64", shape=(None,)) + out = x[idx] + y + fn = pytensor.function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Both idx and y have length 1 — should work (both agree on dim 0) + result = fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(1)) + assert result.shape == (1,) + # idx=1, y=5 — should error (shape mismatch, no static broadcast info) + with pytest.raises(Exception): + fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(5)) + + def test_shared_read_write_index_no_broadcast(self): + """A shared read+write index must not broadcast at runtime.""" + x = pt.vector("x", shape=(None,)) + y = pt.vector("y", shape=(None,)) + t = pt.vector("t", shape=(None,)) + idx = pt.vector("idx", dtype="int64", shape=(None,)) + out = pt.inc_subtensor(t[idx], x[idx] + y) + fn = pytensor.function([x, y, t, idx], out, mode=NUMBA_MODE, trust_input=True) + fn(np.zeros(100), np.zeros(50), np.zeros(100), np.arange(50, dtype=np.int64)) + with pytest.raises(Exception): + fn( + np.zeros(100), + np.zeros(50), + np.zeros(100), + np.array([0], dtype=np.int64), + ) + + def test_mismatched_nd_index_dims(self): + """ND index shape doesn't match direct input shape.""" + x = pt.vector("x", shape=(None,)) + mat_idx = pt.matrix("mat_idx", dtype="int64", shape=(None, None)) + y = pt.matrix("y", shape=(None, None)) + out = 
x[mat_idx] + y + fn = pytensor.function([x, mat_idx, y], out, mode=NUMBA_MODE, trust_input=True) + assert_fused(fn) + # Matching: mat_idx=(3,4), y=(3,4) — should work + fn(np.zeros(100), np.zeros((3, 4), dtype=np.int64), np.zeros((3, 4))) + # Mismatched: mat_idx=(3,4), y=(3,5) — should error + with pytest.raises(Exception): + fn(np.zeros(100), np.zeros((3, 4), dtype=np.int64), np.zeros((3, 5))) + + # ============================================================ + # Write-only index shape validation + # ============================================================ + + def test_write_only_index_no_broadcast(self): + """A write-only index must not broadcast against the loop. + + target[idx] += exp(y) where idx is write-only (not used for reads). + If idx has shape (1,) at runtime but y has shape (50,), the loop + runs 50 iterations. The index must not silently broadcast — it + should error because the shapes don't match. + """ + rng = np.random.default_rng(42) + y = pt.vector("y", shape=(None,)) + t = pt.vector("t", shape=(None,)) + idx = pt.vector("idx", dtype="int64", shape=(None,)) + out = pt.inc_subtensor(t[idx], pt.exp(y)) + fn = pytensor.function([y, t, idx], out, mode=NUMBA_MODE, trust_input=True) + fn_u = pytensor.function( + [y, t, idx], out, mode=NUMBA_NO_FUSION, trust_input=True + ) + assert_fused(fn) + # Matching shapes — should work + yv = rng.normal(size=50) + tv = rng.normal(size=100) + iv = rng.integers(100, size=50).astype(np.int64) + np.testing.assert_allclose(fn(yv, tv, iv), fn_u(yv, tv, iv), rtol=1e-10) + # Mismatched: idx=1, y=50 — should error + with pytest.raises(Exception): + fn(np.zeros(50), np.zeros(100), np.array([0], dtype=np.int64)) + + +class TestRepeatedAccumulationIndices: + """Test inc_subtensor with repeated indices (same position accumulated multiple times).""" + + def test_repeated_inc_subtensor(self): + """inc_subtensor with repeated indices should accumulate correctly.""" + rng = np.random.default_rng(42) + target = pt.vector("target", 
shape=(10,)) + x = pt.vector("x", shape=(20,)) + # Indices with repeats: multiple values go to the same target position + idx = np.array( + [0, 1, 2, 0, 1, 2, 3, 3, 3, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3], + dtype=np.int64, + ) + out = pt.inc_subtensor(target[idx], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, x], out) + assert_fused(fn) + tv = rng.normal(size=10) + xv = rng.normal(size=20) + np.testing.assert_allclose(fn(tv, xv), fn_u(tv, xv), rtol=1e-10) + + def test_repeated_set_subtensor(self): + """set_subtensor with repeated indices -- last write wins.""" + rng = np.random.default_rng(42) + target = pt.vector("target", shape=(10,)) + x = pt.vector("x", shape=(5,)) + idx = np.array([0, 0, 1, 1, 2], dtype=np.int64) + out = pt.set_subtensor(target[idx], pt.exp(x)) + fn, fn_u = fused_and_unfused([target, x], out) + assert_fused(fn) + tv = rng.normal(size=10) + xv = rng.normal(size=5) + np.testing.assert_allclose(fn(tv, xv), fn_u(tv, xv), rtol=1e-10) + + def test_repeated_inc_with_read(self): + """Combined read + inc_subtensor with repeated indices.""" + rng = np.random.default_rng(42) + source = pt.vector("source", shape=(10,)) + target = pt.vector("target", shape=(10,)) + idx = np.array([0, 1, 0, 1, 2, 2, 3, 3], dtype=np.int64) + out = pt.inc_subtensor(target[idx], source[idx] * 2.0) + fn, fn_u = fused_and_unfused([source, target], out) + assert_fused(fn) + sv = rng.normal(size=10) + tv = rng.normal(size=10) + np.testing.assert_allclose(fn(sv, tv), fn_u(sv, tv), rtol=1e-10) + + +class TestRadonModel: + """Test fusion on the radon hierarchical model (logp + gradient).""" + + def test_radon_model_correctness(self): + import sys + + sys.path.insert(0, "tests/benchmarks") + from test_compilation import create_radon_model + + joined_inputs, [model_logp, model_dlogp] = create_radon_model() + fn, fn_u = fused_and_unfused([joined_inputs], [model_logp, model_dlogp]) + rng = np.random.default_rng(1) + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) + 
results_fused = fn(x) + results_unfused = fn_u(x) + for i, (rf, ru) in enumerate(zip(results_fused, results_unfused)): + np.testing.assert_allclose(rf, ru, rtol=1e-6, err_msg=f"Output {i}") + + def test_radon_model_no_unfused_indexing(self): + """After fusion, no AdvancedSubtensor1 or AdvancedIncSubtensor1 should remain.""" + import sys + + sys.path.insert(0, "tests/benchmarks") + from test_compilation import create_radon_model + + from pytensor.tensor.subtensor import AdvancedSubtensor1 + + joined_inputs, [model_logp, model_dlogp] = create_radon_model() + fn = pytensor.function( + [joined_inputs], + [model_logp, model_dlogp], + mode=NUMBA_MODE, + trust_input=True, + ) + nodes = fn.maker.fgraph.toposort() + assert not any(isinstance(n.op, AdvancedSubtensor1) for n in nodes) + assert not any(isinstance(n.op, AdvancedIncSubtensor1) for n in nodes) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 6dbaa75b85..4d77b17268 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -996,16 +996,13 @@ def test_expansion_order(self): (np.sum(fxv + 5) * np.exp(fxv) / (fxv + 5),), ("float32",), ), - pytest.param( - ( - (sin(exp(fx)), exp(sin(fx))), - (fx,), - (fxv,), - 1, - (np.sin(np.exp(fxv)), np.exp(np.sin(fxv))), - ("float32", "float32"), - ), - marks=pytest.mark.xfail, # Not implemented yet + ( + (sin(exp(fx)), exp(sin(fx))), + (fx,), + (fxv,), + 1, + (np.sin(np.exp(fxv)), np.exp(np.sin(fxv))), + ("float32", "float32"), ), ], ) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 406cd06f48..3f50c32425 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4651,7 +4651,9 @@ def test_polygamma_specialization(): y3 = polygamma(2, x) fn = pytensor.function( - [x], [y1, y2, y3], mode=get_default_mode().including("specialize") + [x], + [y1, y2, y3], + 
mode=get_default_mode().including("specialize").excluding("fusion"), ) fn_outs = fn.maker.fgraph.outputs assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)