diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index b59cc9992f..0c25934f18 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -10,6 +10,7 @@ you probably want to use pytensor.tensor.[c,z,f,d,b,w,i,l,]scalar! """ +import abc import builtins import math from collections.abc import Callable @@ -36,7 +37,6 @@ from pytensor.utils import ( apply_across_args, difference, - to_return_values, ) @@ -1250,8 +1250,9 @@ def perform(self, node, inputs, output_storage): ): storage[0] = _cast_to_promised_scalar_dtype(variable, out.dtype) + @abc.abstractmethod def impl(self, *inputs): - raise MethodNotDefined("impl", type(self), self.__class__.__name__) + raise NotImplementedError() def grad(self, inputs, output_gradients): raise MethodNotDefined("grad", type(self), self.__class__.__name__) @@ -3840,7 +3841,7 @@ class Real(UnaryScalarOp): """ - # numpy.real(float32) return a view on the inputs. + # numpy.real(float32) return a view on the inputs, which ain't good for elemwise. # nfunc_spec = ('real', 1, 1) def impl(self, x): @@ -4054,39 +4055,27 @@ def inner_outputs(self): @property def py_perform_fn(self): - if hasattr(self, "_py_perform_fn"): + """Compiled Python function that chains inner ops' ``impl`` methods. + + Returns a callable that takes scalar inputs and returns a tuple of outputs. 
+ """ + try: return self._py_perform_fn + except AttributeError: + pass from pytensor.link.utils import fgraph_to_python - def python_convert(op, node=None, **kwargs): - assert node is not None - - n_outs = len(node.outputs) + def impl_convert(op, node=None, **kwargs): + return op.impl - if n_outs > 1: - - def _perform(*inputs, outputs=[[None]] * n_outs): - op.perform(node, inputs, outputs) - return tuple(o[0] for o in outputs) - - else: - - def _perform(*inputs, outputs=[[None]]): - op.perform(node, inputs, outputs) - return outputs[0][0] - - return _perform - - self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert) + self._py_perform_fn = fgraph_to_python(self.fgraph, impl_convert) return self._py_perform_fn def impl(self, *inputs): - output_storage = [[None] for i in range(self.nout)] - self.perform(None, inputs, output_storage) - ret = to_return_values([storage[0] for storage in output_storage]) - if self.nout > 1: - ret = tuple(ret) + ret = self.py_perform_fn(*inputs) + if self.nout == 1: + return ret[0] return ret def c_code_cache_version(self): @@ -4311,9 +4300,7 @@ def make_node(self, *inputs): return node def perform(self, node, inputs, output_storage): - outputs = self.py_perform_fn(*inputs) - # zip strict not specified because we are in a hot loop - for storage, out_val in zip(output_storage, outputs): + for storage, out_val in zip(output_storage, self.py_perform_fn(*inputs)): storage[0] = out_val def grad(self, inputs, output_grads): diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 553c538296..d678b1850e 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -2437,12 +2437,19 @@ def connection_pattern(self, node): return connection_pattern def L_op(self, inputs, outs, dC_douts): - # `grad_step` equals the number of steps the original scan node has - # done (if the original scan is a while loop than this number is the - # length of the output sequence) - # We do not know what kind of outputs the original scan has, so we - 
# try first to see if it has a nit_sot output, then a sit_sot and - # then a mit_sot + # Computes the gradient of this Scan by constructing a new backward Scan + # that runs in reverse. The method: + # 1. Differentiates the inner function symbolically (compute_all_gradients) + # 2. Adds accumulation terms for state inputs at preserved buffer positions + # 3. Builds reversed sequences from the forward outputs + # 4. Converts all recurrent states (sit-sot, mit-sot, mit-mot) into mit-mot + # form in the backward scan (initialized with output gradients, accumulate + # total gradients after evaluation) + # 5. Constructs and runs the backward Scan, then re-orders its outputs + + # Determine the number of gradient steps from the output shapes (not from + # inputs[0] directly, because while-loop scans may execute fewer steps than + # the allocated buffer size). info = self.info if info.n_nit_sot > 0: grad_steps = self.outer_nitsot_outs(outs)[0].shape[0] @@ -2457,8 +2464,7 @@ def L_op(self, inputs, outs, dC_douts): if info.as_while: n_steps = outs[0].shape[0] - # Restrict the number of grad steps according to - # self.truncate_gradient + # Restrict the number of grad steps according to self.truncate_gradient if self.truncate_gradient != -1: grad_steps = minimum(grad_steps, self.truncate_gradient) @@ -2540,13 +2546,11 @@ def compute_all_gradients(known_grads): ] gmp = {} - # Required in case there is a pair of variables X and Y, with X - # used to compute Y, for both of which there is an external - # gradient signal. Without this, the total gradient signal on X - # will be the external gradient signalknown_grads[X]. With this, - # it will be the sum of the external gradient signal and the - # gradient obtained by propagating Y's external gradient signal - # to X. + # The .copy() creates fresh variable nodes so that grad() treats them + # as new outputs "equal to" the originals, rather than matching them by + # identity to variables already in the graph. 
This forces grad() to + # propagate the known_grads values backward through the computation + # instead of short-circuiting at a wrt target. known_grads = {k.copy(): v for (k, v) in known_grads.items()} grads = grad( @@ -2588,17 +2592,15 @@ def compute_all_gradients(known_grads): Xt_placeholder = safe_new(Xt) Xts.append(Xt_placeholder) - # Different processing based on whether Xt is a nitsot output - # or not. NOTE : This cannot be done by using - # "if Xt not in self.inner_nitsot_outs(self_outputs)" because - # the exact same variable can be used as multiple outputs. + # Different processing based on whether Xt is a nitsot output or not. + # NOTE : This cannot be done by using "if Xt not in self.inner_nitsot_outs(self_outputs)" + # because the exact same variable can be used as multiple outputs. if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end: - # What we do here is loop through dC_douts and collect all + # loop through dC_douts and collect all # those that are connected to the specific one and do an # upcast on all of their dtypes to get the dtype for this # specific output. Deciding if the gradient with this - # specific previous step is defined or not is done somewhere - # else. + # specific previous step is defined or not is done somewhere else. dtypes = [] for pos, inp in enumerate(states): if inp in graph_inputs([Xt]): @@ -2637,9 +2639,9 @@ def compute_all_gradients(known_grads): continue # Just some trouble to avoid a +0 - if diff_outputs[i] in known_grads: + try: known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] - else: + except KeyError: known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] dc_dxts_idx += 1 @@ -2655,6 +2657,9 @@ def compute_all_gradients(known_grads): ) else: disconnected_dC_dinps_t[dx] = False + # Replace inner output subexpressions with placeholders wired to the + # saved forward values, so the backward scan reuses them instead of + # recomputing them. See forced_replace docstring for details. 
for Xt, Xt_placeholder in zip( diff_outputs[info.n_mit_mot_outs :], Xts, strict=True ): @@ -2663,21 +2668,20 @@ def compute_all_gradients(known_grads): # construct dX_dtm1 dC_dXtm1s = [] + n_internal_recurrent_states = sum( + len(t) + for t in chain( + info.mit_mot_in_slices, + info.mit_sot_in_slices, + info.sit_sot_in_slices, + ) + ) for pos, x in enumerate(dC_dinps_t[info.n_seqs :]): - # Get the index of the first inner input corresponding to the - # pos-ieth inner input state + # Get the index of the first inner input corresponding to the pos-ieth inner input state idxs = var_mappings["inner_out_from_inner_inp"][info.n_seqs + pos] - # Check if the pos-th input is associated with one of the - # recurrent states - x_is_state = pos < sum( - len(t) - for t in chain( - info.mit_mot_in_slices, - info.mit_sot_in_slices, - info.sit_sot_in_slices, - ) - ) + # Check if the pos-th input is associated with one of the recurrent states + x_is_state = pos < n_internal_recurrent_states if x_is_state and len(idxs) > 0: opos = idxs[0] @@ -2687,7 +2691,26 @@ def compute_all_gradients(known_grads): else: dC_dXtm1s.append(safe_new(x)) + # Skip accumulation for "overlapping" mit-mot taps. + # A mit-mot tap "overlaps" when the same tap index appears in both the input + # and output slices of a single mit-mot state. This means the output *overwrites* + # the input at that buffer position — analogous to set_subtensor(x, y, i). + # The gradient of an overwrite must zero out the direct pass-through from the + # old value; the only gradient path is through the output expression that replaced + # it (already captured by compute_all_gradients via known_grads). 
+ overlapping_taps = set() + dx_offset = 0 + for idx in range(info.n_mit_mot): + in_taps = info.mit_mot_in_slices[idx] + out_taps = info.mit_mot_out_slices[idx] + for k, tap in enumerate(in_taps): + if tap in out_taps: + overlapping_taps.add(dx_offset + k) + dx_offset += len(in_taps) + for dx, dC_dXtm1 in enumerate(dC_dXtm1s): + if dx in overlapping_taps: + continue # gradient truncates here if isinstance(dC_dinps_t[dx + info.n_seqs].type, NullType): # The accumulated gradient is undefined pass @@ -2761,8 +2784,7 @@ def compute_all_gradients(known_grads): outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] - # Restrict the length of the outer sequences to the number of grad - # steps + # Restrict the length of the outer sequences to the number of grad steps outer_inp_seqs = [s_[:grad_steps] for s_ in outer_inp_seqs] inner_inp_seqs = self.inner_seqs(self_inputs) @@ -2771,7 +2793,14 @@ def compute_all_gradients(known_grads): inner_inp_seqs += self.inner_sitsot(self_inputs) inner_inp_seqs += self.inner_nitsot_outs(dC_dXts) inner_inp_seqs += Xts - # mitmot + # Build backward scan's mit-mot states. + # Every forward recurrent state (sit-sot, mit-sot, mit-mot) becomes + # a mit-mot in the backward scan. The conversion negates the taps: + # forward output tap t → backward input tap -t (gradient signal) + # forward input tap t → backward output tap -t (gradient to propagate) + # Each backward output tap also needs a backward input tap at the same + # position to carry the accumulated gradient (the recurrence). If one + # already exists from the first rule, they share the buffer slot. 
outer_inp_mitmot = [] inner_inp_mitmot = [] inner_out_mitmot = [] @@ -2810,8 +2839,8 @@ def compute_all_gradients(known_grads): inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs]) if isinstance(dC_dinps_t[ins_pos].type, NullType): - # We cannot use Null in the inner graph, so we - # use a zero tensor of the appropriate shape instead. + # We cannot use Null in the inner graph, + # so we use a zero tensor of the appropriate shape instead. inner_out_mitmot.append( pt.zeros(diff_inputs[ins_pos].shape, dtype=config.floatX) ) @@ -2919,9 +2948,8 @@ def compute_all_gradients(known_grads): outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) else: if isinstance(dC_dinps_t[ins_pos].type, NullType): - # Cannot use dC_dinps_t[ins_pos].dtype, so we use - # floatX instead, as it is a dummy value that will not - # be used anyway. + # Cannot use dC_dinps_t[ins_pos].dtype, so we use floatX instead, + # as it is a dummy value that will not be used anyway. outer_inp_mitmot.append( pt.zeros(outs[idx + offset].shape, dtype=config.floatX) ) @@ -2933,8 +2961,8 @@ def compute_all_gradients(known_grads): ) if isinstance(dC_dinps_t[ins_pos].type, NullType): - # We cannot use Null in the inner graph, so we - # use a zero tensor of the appropriate shape instead. + # We cannot use Null in the inner graph, + # so we use a zero tensor of the appropriate shape instead. 
inner_out_mitmot.append( pt.zeros(diff_inputs[ins_pos].shape, dtype=config.floatX) ) @@ -2974,8 +3002,7 @@ def compute_all_gradients(known_grads): through_untraced = True if isinstance(vl.type, NullType): type_outs.append(vl.type.why_null) - # Replace the inner output with a zero tensor of - # the right shape + # Replace the inner output with a zero tensor of the right shape inner_out_sitsot[_p] = pt.zeros( diff_inputs[ins_pos + _p].shape, dtype=config.floatX ) @@ -2993,8 +3020,7 @@ def compute_all_gradients(known_grads): through_untraced = True if isinstance(vl.type, NullType): type_outs.append(vl.type.why_null) - # Replace the inner output with a zero tensor of - # the right shape + # Replace the inner output with a zero tensor of the right shape inner_out_nitsot[_p] = pt.zeros( diff_inputs[_p].shape, dtype=config.floatX ) @@ -3089,9 +3115,8 @@ def compute_all_gradients(known_grads): ) ): if t == "connected": - # If the forward scan is in as_while mode, we need to pad - # the gradients, so that they match the size of the input - # sequences. + # If the forward scan is in as_while mode, we need to pad the gradients, + # so that they match the size of the input sequences. if info.as_while: n_zeros = inputs[0] - n_steps shp = (n_zeros,) @@ -3117,9 +3142,8 @@ def compute_all_gradients(known_grads): end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end], strict=True)): if t == "connected": - # If the forward scan is in as_while mode, we need to pad - # the gradients, so that they match the size of the input - # sequences. + # If the forward scan is in as_while mode, we need to pad the gradients, + # so that they match the size of the input sequences. 
if info.as_while: n_zeros = inputs[0] - grad_steps shp = (n_zeros,) diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index abd6216110..b61f1b115e 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -1279,6 +1279,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: The scan perform implementation takes the output sizes into consideration, saving the newest results over the oldest ones whenever the buffer is filled. + This rewrite must only run at compilation time, after grad() has already + built the backward scan. The backward scan needs all intermediate forward + states as sequence inputs (to evaluate f'(x[t])). If this rewrite truncates + buffers before grad() is called, the gradient will be silently wrong. + TODO: Use a subclass that raises explicitly on `L_op` + Paramaters ---------- backend_supports_output_pre_allocation: bool diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index 8ac3c58a99..b1426fa463 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -1099,9 +1099,22 @@ def __eq__(self, other): def forced_replace(out, x, y): """ - Check all internal values of the graph that compute the variable ``out`` - for occurrences of values identical with ``x``. If such occurrences are - encountered then they are replaced with variable ``y``. + Replace subexpressions in ``out`` that are structurally equal to ``x`` + with ``y``, using ``equal_computations`` for matching. + + Unlike ``graph_replace`` (which matches by variable identity), + this detects when a subexpression *recomputes* ``x`` without + being the same variable object. This is used by ``Scan.L_op`` + to substitute inner-function outputs with placeholders wired to + the saved forward values, avoiding redundant recomputation in + the backward scan. 
For example, if ``exp(x).L_op`` returns + ``output_gradient * exp(x)`` by recreating ``exp(x)`` instead + of referencing the existing output variable, a plain identity + check would miss it, but ``equal_computations`` catches it. + + This is not comprehensive: structurally different but semantically + equivalent expressions (e.g. ``exp(x + 0)`` vs ``exp(x)``) will + not match. Parameters ---------- @@ -1117,7 +1130,7 @@ def forced_replace(out, x, y): Notes ----- - When it find a match, it don't continue on the corresponding inputs. + When it finds a match, it does not continue into that node's inputs. """ if out is None: return None diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 9c8fa86ac7..ab281bd352 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -49,10 +49,17 @@ def _vectorize_node_perform( batch_bcast_patterns: Sequence[tuple[bool, ...]], batch_ndim: int, impl: str | None, + inplace_mapping: tuple[int | None, ...] | None = None, ) -> Callable: """Creates a vectorized `perform` function for a given core node. Similar behavior of np.vectorize, but specialized for PyTensor Blockwise Op. + + Parameters + ---------- + inplace_mapping + Optional tuple of length ``nout``. Entry ``i`` is the input index that + output ``i`` should be written into, or ``None`` to allocate a fresh array. 
""" storage_map = {var: [None] for var in core_node.inputs + core_node.outputs} @@ -75,6 +82,7 @@ def _vectorize_node_perform( def vectorized_perform( *args, + out=None, batch_bcast_patterns=batch_bcast_patterns, batch_ndim=batch_ndim, single_in=single_in, @@ -82,7 +90,11 @@ def vectorized_perform( core_input_storage=core_input_storage, core_output_storage=core_output_storage, core_storage=core_storage, + inplace_mapping=inplace_mapping, ): + if inplace_mapping is not None: + out = tuple(args[j] if j is not None else None for j in inplace_mapping) + if single_in: batch_shape = args[0].shape[:batch_ndim] else: @@ -106,10 +118,22 @@ def vectorized_perform( for core_input, arg in zip(core_input_storage, args): core_input[0] = np.asarray(arg[index0]) core_thunk() - outputs = tuple( - empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype) - for core_output in core_output_storage - ) + if out is None: + outputs = tuple( + empty( + batch_shape + core_output[0].shape, dtype=core_output[0].dtype + ) + for core_output in core_output_storage + ) + else: + outputs = tuple( + o + if o is not None + else empty( + batch_shape + core_output[0].shape, dtype=core_output[0].dtype + ) + for o, core_output in zip(out, core_output_storage) + ) for output, core_output in zip(outputs, core_output_storage): output[index0] = core_output[0] diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 96446237a4..19a1a5280f 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from copy import copy from textwrap import dedent from typing import Literal @@ -357,6 +356,8 @@ def __init__( assert not isinstance(scalar_op, type(self)) if inplace_pattern is None: inplace_pattern = frozendict({}) + elif not isinstance(inplace_pattern, frozendict): + inplace_pattern = frozendict(inplace_pattern) self.name = name self.scalar_op = scalar_op self.inplace_pattern = inplace_pattern @@ -365,22 +366,8 
@@ def __init__( if nfunc_spec is None: nfunc_spec = getattr(scalar_op, "nfunc_spec", None) self.nfunc_spec = nfunc_spec - self.__setstate__(self.__dict__) super().__init__(openmp=openmp) - def __getstate__(self): - d = copy(self.__dict__) - d.pop("ufunc") - d.pop("nfunc") - d.pop("__epydoc_asRoutine", None) - return d - - def __setstate__(self, d): - super().__setstate__(d) - self.ufunc = None - self.nfunc = None - self.inplace_pattern = frozendict(self.inplace_pattern) - def get_output_info(self, *inputs): """Return the outputs dtype and broadcastable pattern and the dimshuffled inputs. @@ -599,53 +586,149 @@ def transform(r): return ret - def prepare_node(self, node, storage_map, compute_map, impl): - # Postpone the ufunc building to the last minutes due to: - # - NumPy ufunc support only up to 32 operands (inputs and outputs) - # But our c code support more. - # - nfunc is reused for scipy and scipy is optional - if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py": - impl = "c" + def _create_node_ufunc(self, node: Apply): + """Define (or retrieve) the node ufunc used in `perform`. - if getattr(self, "nfunc_spec", None) and impl != "c": - self.nfunc = import_func_from_string(self.nfunc_spec[0]) + For scalar (0-d) outputs, calls ``scalar_op.impl`` directly. + For tensor outputs with ``nfunc_spec``, uses the numpy/scipy ufunc. + Otherwise, ``np.frompyfunc`` (≤32 operands) or Blockwise vectorize (>32). - if ( - (len(node.inputs) + len(node.outputs)) <= 32 - and (self.nfunc is None or self.scalar_op.nin != len(node.inputs)) - and self.ufunc is None - and impl == "py" - ): + All returned callables accept ``(*inputs)`` and return a tuple of outputs. + The ``inplace_pattern`` is baked into the closure so that inplace outputs + are written directly into the corresponding input arrays. + + The ufunc is stored in the tag of the node. 
+ """ + inplace_pattern = self.inplace_pattern + nout = len(node.outputs) + out_dtypes = tuple(out.type.numpy_dtype for out in node.outputs) + # Pre-compute output→input index mapping for inplace + out_to_in = ( + tuple(inplace_pattern.get(i) for i in range(nout)) + if inplace_pattern + else () + ) + + if (nfunc_spec := self.nfunc_spec) is not None and len( + node.inputs + ) == nfunc_spec[1]: + ufunc = import_func_from_string(nfunc_spec[0]) + if ufunc is None: + raise ValueError(f"Could not import gufunc {nfunc_spec[0]} for {self}") + # When inputs are discrete and output is float, pass a signature + # to prevent numpy from computing in float16 for int8 inputs + ufunc_kwargs = {} + if ( + isinstance(ufunc, np.ufunc) + and any(inp.dtype in discrete_dtypes for inp in node.inputs) + and any(out.dtype in float_dtypes for out in node.outputs) + ): + in_sig = "".join(np.dtype(inp.dtype).char for inp in node.inputs) + out_sig = "".join(np.dtype(out.dtype).char for out in node.outputs) + ufunc_kwargs["sig"] = f"{in_sig}->{out_sig}" + + if out_to_in and isinstance(ufunc, np.ufunc): + # Only numpy ufuncs support out=; other nfunc_spec functions (e.g. 
np.where) don't + if nout == 1: + + def ufunc_fn( + *inputs, _ufunc=ufunc, _kwargs=ufunc_kwargs, _j=out_to_in[0] + ): + _ufunc(*inputs, out=inputs[_j], **_kwargs) + return (inputs[_j],) + else: + + def ufunc_fn( + *inputs, + _ufunc=ufunc, + _kwargs=ufunc_kwargs, + _out_to_in=out_to_in, + ): + out = tuple( + inputs[j] if j is not None else None for j in _out_to_in + ) + return _ufunc(*inputs, out=out, **_kwargs) + elif nout == 1: + + def ufunc_fn(*inputs, _ufunc=ufunc, _kwargs=ufunc_kwargs): + return (_ufunc(*inputs, **_kwargs),) + else: + + def ufunc_fn(*inputs, _ufunc=ufunc, _kwargs=ufunc_kwargs): + return _ufunc(*inputs, **_kwargs) + + node.tag.ufunc = ufunc_fn + return ufunc_fn + + # No nfunc_spec path + if node.outputs[0].type.ndim == 0: + # Scalar outputs: call impl directly, wrap with np.asarray + impl = self.scalar_op.impl + if nout == 1: + + def ufunc_fn(*inputs, _impl=impl, _dt=out_dtypes[0]): + return (np.asarray(_impl(*inputs), dtype=_dt),) + else: + + def ufunc_fn(*inputs, _impl=impl, _dts=out_dtypes): + return tuple( + np.asarray(r, dtype=dt) for r, dt in zip(_impl(*inputs), _dts) + ) + + node.tag.ufunc = ufunc_fn + return ufunc_fn + + # ndim > 0 without nfunc_spec: frompyfunc (≤32 operands) or Blockwise vectorize (>32) + # frompyfunc returns object arrays — .astype() converts to the correct dtype. + # No inplace: destroy_map is a permission, not an obligation. frompyfunc + # already allocates an object array, so copying into the input would just waste time. + if len(node.inputs) + len(node.outputs) <= 32: ufunc = np.frompyfunc( self.scalar_op.impl, len(node.inputs), self.scalar_op.nout ) - if self.scalar_op.nin > 0: - # We can reuse it for many nodes - self.ufunc = ufunc + + if nout == 1: + + def ufunc_fn(*inputs, _ufunc=ufunc, _dt=out_dtypes[0]): + return (_ufunc(*inputs).astype(_dt),) else: - node.tag.ufunc = ufunc - - # Numpy ufuncs will sometimes perform operations in - # float16, in particular when the input is int8. 
- # This is not something that we want, and we do not - # do it in the C code, so we specify that the computation - # should be carried out in the returned dtype. - # This is done via the "sig" kwarg of the ufunc, its value - # should be something like "ff->f", where the characters - # represent the dtype of the inputs and outputs. - - # NumPy 1.10.1 raise an error when giving the signature - # when the input is complex. So add it only when inputs is int. - out_dtype = node.outputs[0].dtype - if ( - out_dtype in float_dtypes - and isinstance(self.nfunc, np.ufunc) - and node.inputs[0].dtype in discrete_dtypes - ): - char = np.dtype(out_dtype).char - sig = char * node.nin + "->" + char * node.nout - node.tag.sig = sig - node.tag.fake_node = Apply( + + def ufunc_fn(*inputs, _ufunc=ufunc, _dts=out_dtypes): + return tuple(r.astype(dt) for r, dt in zip(_ufunc(*inputs), _dts)) + else: + # frompyfunc limited to 32 operands, fall back to Blockwise vectorize + from pytensor.tensor.blockwise import _vectorize_node_perform + + core_node = Apply( + self.scalar_op, + [ + get_scalar_type(dtype=inp.type.dtype).make_variable() + for inp in node.inputs + ], + [ + get_scalar_type(dtype=out.type.dtype).make_variable() + for out in node.outputs + ], + ) + batch_ndim = node.outputs[0].type.ndim + batch_bcast_patterns = tuple(inp.type.broadcastable for inp in node.inputs) + ufunc_fn = _vectorize_node_perform( + core_node, + batch_bcast_patterns, + batch_ndim, + impl="py", + inplace_mapping=out_to_in or None, + ) + + node.tag.ufunc = ufunc_fn + return ufunc_fn + + def prepare_node(self, node, storage_map, compute_map, impl=None): + if impl != "c": + node.tag.ufunc = self._create_node_ufunc(node) + + # Create a dummy scalar node for the scalar_op to prepare itself + node.tag.dummy_node = dummy_node = Apply( self.scalar_op, [ get_scalar_type(dtype=input.type.dtype).make_variable() @@ -656,77 +739,18 @@ def prepare_node(self, node, storage_map, compute_map, impl): for output in node.outputs 
], ) - - self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl) + self.scalar_op.prepare_node(dummy_node, None, None, impl) def perform(self, node, inputs, output_storage): - if (len(node.inputs) + len(node.outputs)) > 32: - # Some versions of NumPy will segfault, other will raise a - # ValueError, if the number of operands in an ufunc is more than 32. - # In that case, the C version should be used, or Elemwise fusion - # should be disabled. - # FIXME: This no longer calls the C implementation! - super().perform(node, inputs, output_storage) - self._check_runtime_broadcast(node, inputs) - - ufunc_args = inputs - ufunc_kwargs = {} - # We supported in the past calling manually op.perform. - # To keep that support we need to sometimes call self.prepare_node - if self.nfunc is None and self.ufunc is None: - self.prepare_node(node, None, None, "py") - if self.nfunc and len(inputs) == self.nfunc_spec[1]: - ufunc = self.nfunc - nout = self.nfunc_spec[2] - if hasattr(node.tag, "sig"): - ufunc_kwargs["sig"] = node.tag.sig - # Unfortunately, the else case does not allow us to - # directly feed the destination arguments to the nfunc - # since it sometimes requires resizing. Doing this - # optimization is probably not worth the effort, since we - # should normally run the C version of the Op. - else: - # the second calling form is used because in certain versions of - # numpy the first (faster) version leads to segfaults - if self.ufunc: - ufunc = self.ufunc - elif not hasattr(node.tag, "ufunc"): - # It happen that make_thunk isn't called, like in - # get_underlying_scalar_constant_value - self.prepare_node(node, None, None, "py") - # prepare_node will add ufunc to self or the tag - # depending if we can reuse it or not. So we need to - # test both again. 
- if self.ufunc: - ufunc = self.ufunc - else: - ufunc = node.tag.ufunc - else: - ufunc = node.tag.ufunc - - nout = ufunc.nout + try: + ufunc = node.tag.ufunc + except AttributeError: + ufunc = node.tag.ufunc = self._create_node_ufunc(node) with np.errstate(all="ignore"): - variables = ufunc(*ufunc_args, **ufunc_kwargs) - - if nout == 1: - variables = [variables] - - # zip strict not specified because we are in a hot loop - for i, (variable, storage, nout) in enumerate( - zip(variables, output_storage, node.outputs) - ): - storage[0] = variable = np.asarray(variable, dtype=nout.dtype) - - if i in self.inplace_pattern: - odat = inputs[self.inplace_pattern[i]] - odat[...] = variable - storage[0] = odat - - # numpy.real return a view! - if not variable.flags.owndata: - storage[0] = variable.copy() + for s, result in zip(output_storage, ufunc(*inputs)): + s[0] = result @staticmethod def _check_runtime_broadcast(node, inputs): @@ -754,8 +778,7 @@ def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...] def _c_all(self, node, nodename, inames, onames, sub): # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code` # To not request all of them to call prepare_node(), do it here. - # There is no harm if it get called multiple times. 
- if not hasattr(node.tag, "fake_node"): + if not hasattr(node.tag, "dummy_node"): self.prepare_node(node, None, None, "c") _inames = inames _onames = onames @@ -903,7 +926,7 @@ def _c_all(self, node, nodename, inames, onames, sub): else: fail = sub["fail"] task_code = self.scalar_op.c_code( - node.tag.fake_node, + node.tag.dummy_node, nodename + "_scalar_", [f"{s}_i" for s in _inames], [f"{s}_i" for s in onames], diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index dc30beedf3..4bc843fea6 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -516,15 +516,26 @@ def flatten_nested_add_mul(fgraph, node): return [output] -def elemwise_max_operands_fct(node) -> int: - # `Elemwise.perform` uses NumPy ufuncs and they are limited to 32 operands (inputs and outputs) - if not config.cxx: - return 32 - return 1024 - - class FusionOptimizer(GraphRewriter): - """Graph optimizer that fuses consecutive Elemwise operations.""" + """Graph optimizer that fuses consecutive Elemwise operations. + + Parameters + ---------- + backend : str + The compilation backend: ``"c"`` or ``"numba"``. + ``"c"`` checks that every scalar op has a C implementation before fusing. + ``"numba"`` fuses unconditionally. + Python mode does not benefit from fusion (``frompyfunc`` iteration + in C is faster than fused ``Composite.impl`` per-element overhead). + """ + + def __init__(self, backend: str): + super().__init__() + if backend not in ("c", "numba"): + raise ValueError( + f"Unsupported backend {backend!r}. Expected 'c' or 'numba'." 
+ ) + self.backend = backend def add_requirements(self, fgraph): fgraph.attach_feature(ReplaceValidate()) @@ -578,18 +589,23 @@ def find_fuseable_subgraphs( This function yields subgraph in reverse topological order so they can be safely replaced one at a time """ - @cache - def elemwise_scalar_op_has_c_code( - node: Apply, optimizer_verbose=config.optimizer_verbose - ) -> bool: - # TODO: This should not play a role in non-c backends! - if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + if self.backend == "c": + # supports_c_code is expensive, cache results + @cache + def elemwise_scalar_op_is_fuseable( + node: Apply, optimizer_verbose=config.optimizer_verbose + ) -> bool: + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + elif optimizer_verbose: + warn( + f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." + ) + return False + else: + # numba: fuse unconditionally + def elemwise_scalar_op_is_fuseable(node: Apply) -> bool: return True - elif optimizer_verbose: - warn( - f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." 
- ) - return False # Create a map from node to a set of fuseable client (successor) nodes # A node and a client are fuseable if they are both single output Elemwise @@ -608,7 +624,7 @@ def elemwise_scalar_op_has_c_code( out_node is not None and len(out_node.outputs) == 1 and isinstance(out_node.op, Elemwise) - and elemwise_scalar_op_has_c_code(out_node) + and elemwise_scalar_op_is_fuseable(out_node) ): continue @@ -621,7 +637,7 @@ def elemwise_scalar_op_has_c_code( len(client.outputs) == 1 and isinstance(client.op, Elemwise) and out_bcast == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) + and elemwise_scalar_op_is_fuseable(client) ) } if out_fuseable_clients: @@ -872,17 +888,10 @@ def elemwise_scalar_op_has_c_code( # Yield from sorted_subgraphs, discarding the subgraph_bitset yield from (io for _, io in sorted_subgraphs) - max_operands = elemwise_max_operands_fct(None) reason = self.__class__.__name__ nb_fused = 0 nb_replacement = 0 for inputs, outputs in find_fuseable_subgraphs(fgraph): - if (len(inputs) + len(outputs)) > max_operands: - warn( - "Loop fusion failed because the resulting node would exceed the kernel argument limit." - ) - continue - scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) composite_outputs = Elemwise( # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables @@ -1172,6 +1181,11 @@ def constant_fold_branches_of_add_mul(fgraph, node): ) # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) +# The outer SequenceDB is backend-agnostic; the actual FusionOptimizer inside +# is registered per-backend (C with "cxx_only", Numba with "numba"). +# Python mode does not benefit from fusion: frompyfunc's C iteration loop is +# faster than fused Composite.impl per-element overhead. +# Shared cleanup rewrites run for any backend that performed fusion. 
fuse_seqopt = SequenceDB() optdb.register( "elemwise_fusion", @@ -1182,13 +1196,22 @@ def constant_fold_branches_of_add_mul(fgraph, node): "FusionOptimizer", position=49, ) +# C backend fusion: checks that scalar ops have C implementations fuse_seqopt.register( "composite_elemwise_fusion", - FusionOptimizer(), + FusionOptimizer(backend="c"), "fast_run", "fusion", + "cxx_only", position=1, ) +# Numba backend fusion: fuses unconditionally +fuse_seqopt.register( + "numba_composite_elemwise_fusion", + FusionOptimizer(backend="numba"), + "numba", + position=1.5, +) fuse_seqopt.register( "local_useless_composite_outputs", dfs_rewriter(local_useless_composite_outputs), @@ -1201,6 +1224,7 @@ def constant_fold_branches_of_add_mul(fgraph, node): dfs_rewriter(local_careduce_fusion), "fast_run", "fusion", + "cxx_only", position=10, ) fuse_seqopt.register( diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 3167a20149..3675721ebb 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -193,8 +193,8 @@ class MultiOutOp(ScalarOp): def make_node(self, x): return Apply(self, [x], [x.type(), x.type()]) - def perform(self, node, inputs, outputs): - outputs[1][0] = outputs[0][0] = inputs[0] + def impl(self, x): + return x, x def c_code(self, *args): return "dummy" @@ -208,6 +208,35 @@ def c_code(self, *args): assert fn(1.0) == [1.0, 1.0] + def test_composite_without_c_code(self): + """Composite of scalar ops without C code should work for py-only execution.""" + from pytensor.scalar.basic import UnaryScalarOp, float64, upcast_out + + class _NoCodeExp(UnaryScalarOp): + nfunc_spec = None + + def impl(self, x): + return np.exp(x) + + def output_types(self, types): + return upcast_out(*types) + + class _NoCodeLog(UnaryScalarOp): + nfunc_spec = None + + def impl(self, x): + return np.log(x) + + def output_types(self, types): + return upcast_out(*types) + + xs = float64("xs") + comp = Composite( + [xs], + 
[_NoCodeExp(name="no_code_exp")(_NoCodeLog(name="no_code_log")(xs))], + ) + assert comp.impl(2.0) == pytest.approx(2.0) + class TestLogical: def test_gt(self): diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index f121dc9e58..318abb363a 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -1332,6 +1332,46 @@ def inner_fct(mitsot_m2, mitsot_m1, sitsot): sum_of_grads = sum(g.sum() for g in gradients) grad(sum_of_grads, inputs[0]) + def test_high_order_grad_sitsot(self): + """Test higher-order derivatives through a sit-sot scan. + + The L_op of a sit-sot scan creates a mit-mot backward scan where + one buffer position is both read and written. + This is analogous to set_subtensor(x, y, i): the gradient w.r.t. x + must zero out position i, routing gradient only through y. + + A bug in the accumulation logic added a spurious gradient at + the overwritten position, as if the old value also passed + through unchanged. The 2nd derivative graph was wrong but + evaluated correctly (the spurious contribution only affected + the mit-mot output, which is not on the gradient path for + scalar derivatives). The error became visible at the 3rd + derivative, where symbolic differentiation through the wrong + graph produced incorrect values. + """ + # Avoid costly rewrite/compilation of Scans + mode = Mode(linker="py", optimizer=None) + x = pt.scalar("x") + x_val = np.float64(0.95) + ys = scan( + fn=lambda xtm1: xtm1**2, outputs_info=[x], n_steps=4, return_updates=False + ) + y = ys[-1] + + # Sanity check + np.testing.assert_allclose(y.eval({x: x_val}, mode=mode), x_val**16) + + # Evaluate higher order derivatives + deriv = y + for order in range(1, 5): + deriv = grad(deriv, x) + deriv_value = deriv.eval({x: x_val}, mode=mode) + # ys[-1] = x^16, so the n-th derivative is 16!/(16-n)! 
* x^(16-n) + expected_deriv_value = np.prod((16, 15, 14, 13)[:order]) * x_val ** ( + 16 - order + ) + np.testing.assert_allclose(deriv_value, expected_deriv_value) + def test_grad_dtype_change(self): x = fscalar("x") y = fscalar("y") diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 2c196401a2..07ddf10cba 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -18,7 +18,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.raise_op import assert_op -from pytensor.scalar.basic import Composite, float64 +from pytensor.scalar.basic import Composite, UnaryScalarOp, float64, upcast_out from pytensor.tensor.basic import MakeVector from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import abs as pt_abs @@ -234,6 +234,7 @@ def test_local_useless_expand_dims_in_reshape(): assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)]) +@pytest.mark.skipif(not config.cxx, reason="Fusion requires a C compiler (cxx_only)") class TestFusion: rewrites = RewriteDatabaseQuery( include=[ @@ -242,7 +243,7 @@ class TestFusion: "add_mul_fusion", "inplace", ], - exclude=["cxx_only", "BlasOpt"], + exclude=["BlasOpt"], ) mode = Mode(get_default_mode().linker, rewrites) _shared = staticmethod(shared) @@ -315,7 +316,7 @@ def test_diamond_graph(): e = c + d fg = FunctionGraph([a], [e], clone=False) - _, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg) + _, nb_fused, nb_replacement, *_ = FusionOptimizer(backend="c").apply(fg) assert nb_fused == 1 assert nb_replacement == 4 @@ -334,7 +335,7 @@ def test_expansion_order(self): e2 = d + b # test both orders fg = FunctionGraph([a], [e1, e2], clone=False) - _, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg) + _, nb_fused, nb_replacement, *_ = FusionOptimizer(backend="c").apply(fg) fg.dprint() assert 
nb_fused == 1 assert nb_replacement == 3 @@ -1075,7 +1076,7 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): assert od == o.dtype def test_fusion_35_inputs(self): - r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" + r"""Make sure we can fuse 35 inputs with the C backend.""" inpts = vectors([f"i{i}" for i in range(35)]) # Make an elemwise graph looking like: @@ -1084,16 +1085,16 @@ def test_fusion_35_inputs(self): for idx in range(1, 35): out = sin(inpts[idx] + out) - with config.change_flags(cxx=""): - f = function(inpts, out, mode=self.mode) + f = function(inpts, out, mode=self.mode) - # Make sure they all weren't fused + # With the C backend, everything should be fused composite_nodes = [ node for node in f.maker.fgraph.toposort() if isinstance(getattr(node.op, "scalar_op", None), ps.basic.Composite) ] - assert not any(len(node.inputs) > 31 for node in composite_nodes) + assert len(composite_nodes) == 1 + assert composite_nodes[0].inputs.__len__() == 35 @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") def test_big_fusion(self): @@ -1173,7 +1174,10 @@ def test_fusion_multiout_inplace(self, linker): inp = np.array([0, 1, 2], dtype=config.floatX) res = f(inp) - assert not np.allclose(inp, [0, 1, 2]) + # destroy_map is a permission, not an obligation. + # The C linker writes inplace; the py linker may not (e.g. frompyfunc path). 
+ if linker != "py": + assert not np.allclose(inp, [0, 1, 2]) assert np.allclose(res[0], [1, 2, 3]) assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2])) @@ -1241,7 +1245,8 @@ def test_test_values(self, test_value): pt_all, np.all, marks=pytest.mark.xfail( - reason="Rewrite logic does not support all CAReduce" + strict=False, + reason="Rewrite logic does not support all CAReduce", ), ), ], @@ -1400,7 +1405,7 @@ def test_eval_benchmark(self, benchmark): def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): inps, outs = getattr(self, graph_fn)(n) fg = FunctionGraph(inps, outs) - opt = FusionOptimizer() + opt = FusionOptimizer(backend="c") def rewrite_func(): fg_clone = fg.clone() @@ -1444,7 +1449,9 @@ def test_joint_circular_dependency(self): for out_order in [(sub, add), (add, sub)]: fgraph = FunctionGraph([x], out_order, clone=True) - _, nb_fused, nb_replaced, *_ = FusionOptimizer().apply(fgraph) + _, nb_fused, nb_replaced, *_ = FusionOptimizer(backend="c").apply( + fgraph + ) # (nb_fused, nb_replaced) would be (2, 5) if we did the invalid fusion assert (nb_fused, nb_replaced) in ((2, 4), (1, 3)) fused_nodes = { @@ -1559,6 +1566,7 @@ def test_local_useless_composite_outputs(): utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]]) +@pytest.mark.skipif(not config.cxx, reason="Fusion requires a C compiler (cxx_only)") @pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)]) @pytest.mark.parametrize("op, np_op", [(pt.pow, np.power), (pt.add, np.add)]) def test_local_inline_composite_constants(op, np_op, const_shape): @@ -1654,3 +1662,73 @@ def test_InplaceElemwiseOptimizer_bug(): finally: # Restore original value to avoid affecting other tests pytensor.config.tensor__insert_inplace_optimizer_validate_nb = original_value + + +# Dummy scalar ops without nfunc_spec — same impl as Exp/Log but forces +# the frompyfunc path (no numpy ufunc shortcut). 
+class _DummyExp(UnaryScalarOp): + nfunc_spec = None + + def impl(self, x): + return np.exp(x) + + def output_types(self, types): + return upcast_out(*types) + + +class _DummyLog(UnaryScalarOp): + nfunc_spec = None + + def impl(self, x): + return np.log(x) + + def output_types(self, types): + return upcast_out(*types) + + +_dummy_exp = Elemwise(_DummyExp(name="dummy_exp")) +_dummy_log = Elemwise(_DummyLog(name="dummy_log")) + + +class TestPyPerformBenchmarks: + """Benchmarks for the Python Elemwise perform path. + + These verify that: + 1. Ops with nfunc_spec (exp, log) use numpy ufuncs directly (SIMD). + 2. Ops without nfunc_spec use frompyfunc (C iteration loop). + """ + + rewrites = RewriteDatabaseQuery( + include=["fusion", "inplace"], + ) + py_mode = Mode("py", rewrites) + + def test_nfunc_spec(self, benchmark): + """sin(cos(x)) with nfunc_spec uses numpy ufuncs directly.""" + x = dvector("x") + out = pt.sin(pt.cos(x)) + f = function([x], out, mode=self.py_mode, trust_input=True) + + # Should be two separate Elemwise nodes (no py fusion) + elemwise_nodes = [ + n for n in f.maker.fgraph.toposort() if isinstance(n.op, Elemwise) + ] + assert len(elemwise_nodes) == 2 + + data = np.random.random(10_000) + benchmark(f, data) + + def test_no_nfunc_spec(self, benchmark): + """dummy_exp(dummy_log(x)) without nfunc_spec uses frompyfunc.""" + x = dvector("x") + out = _dummy_exp(_dummy_log(x)) + f = function([x], out, mode=self.py_mode, trust_input=True) + + # No py fusion — should be two separate Elemwise nodes + elemwise_nodes = [ + n for n in f.maker.fgraph.toposort() if isinstance(n.op, Elemwise) + ] + assert len(elemwise_nodes) == 2 + + data = np.random.random(10_000) + benchmark(f, data) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 8bd35bca28..95a9122237 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -417,7 +417,10 @@ def test_fill(self): xv = rval((5, 5)) yv = rval((1, 1)) f(xv, yv) - assert (xv 
== yv).all() + # destroy_map is a permission, not an obligation. + # PerformLinker with frompyfunc may not write inplace. + if not isinstance(linker(), PerformLinker): + assert (xv == yv).all() def test_fill_var(self): x = matrix() @@ -1118,7 +1121,7 @@ def make_node(self, *inputs): outputs = [float_op(), int_op()] return Apply(self, inputs, outputs) - def perform(self, node, inputs, outputs): + def impl(self, *inputs): raise NotImplementedError() def L_op(self, inputs, outputs, output_gradients):