Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 18 additions & 31 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
you probably want to use pytensor.tensor.[c,z,f,d,b,w,i,l,]scalar!
"""

import abc
import builtins
import math
from collections.abc import Callable
Expand All @@ -36,7 +37,6 @@
from pytensor.utils import (
apply_across_args,
difference,
to_return_values,
)


Expand Down Expand Up @@ -1250,8 +1250,9 @@ def perform(self, node, inputs, output_storage):
):
storage[0] = _cast_to_promised_scalar_dtype(variable, out.dtype)

@abc.abstractmethod
def impl(self, *inputs):
raise MethodNotDefined("impl", type(self), self.__class__.__name__)
raise NotImplementedError()

def grad(self, inputs, output_gradients):
raise MethodNotDefined("grad", type(self), self.__class__.__name__)
Expand Down Expand Up @@ -3840,7 +3841,7 @@ class Real(UnaryScalarOp):

"""

# numpy.real(float32) return a view on the inputs.
# numpy.real(float32) returns a view on the inputs, which isn't good for elemwise.
# nfunc_spec = ('real', 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -4054,39 +4055,27 @@ def inner_outputs(self):

@property
def py_perform_fn(self):
if hasattr(self, "_py_perform_fn"):
"""Compiled Python function that chains inner ops' ``impl`` methods.

Returns a callable that takes scalar inputs and returns a tuple of outputs.
"""
try:
return self._py_perform_fn
except AttributeError:
pass

from pytensor.link.utils import fgraph_to_python

def python_convert(op, node=None, **kwargs):
assert node is not None

n_outs = len(node.outputs)
def impl_convert(op, node=None, **kwargs):
return op.impl

if n_outs > 1:

def _perform(*inputs, outputs=[[None]] * n_outs):
op.perform(node, inputs, outputs)
return tuple(o[0] for o in outputs)

else:

def _perform(*inputs, outputs=[[None]]):
op.perform(node, inputs, outputs)
return outputs[0][0]

return _perform

self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
self._py_perform_fn = fgraph_to_python(self.fgraph, impl_convert)
return self._py_perform_fn

def impl(self, *inputs):
output_storage = [[None] for i in range(self.nout)]
self.perform(None, inputs, output_storage)
ret = to_return_values([storage[0] for storage in output_storage])
if self.nout > 1:
ret = tuple(ret)
ret = self.py_perform_fn(*inputs)
if self.nout == 1:
return ret[0]
return ret

def c_code_cache_version(self):
Expand Down Expand Up @@ -4311,9 +4300,7 @@ def make_node(self, *inputs):
return node

def perform(self, node, inputs, output_storage):
outputs = self.py_perform_fn(*inputs)
# zip strict not specified because we are in a hot loop
for storage, out_val in zip(output_storage, outputs):
for storage, out_val in zip(output_storage, self.py_perform_fn(*inputs)):
storage[0] = out_val

def grad(self, inputs, output_grads):
Expand Down
136 changes: 80 additions & 56 deletions pytensor/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2437,12 +2437,19 @@ def connection_pattern(self, node):
return connection_pattern

def L_op(self, inputs, outs, dC_douts):
# `grad_step` equals the number of steps the original scan node has
# done (if the original scan is a while loop than this number is the
# length of the output sequence)
# We do not know what kind of outputs the original scan has, so we
# try first to see if it has a nit_sot output, then a sit_sot and
# then a mit_sot
# Computes the gradient of this Scan by constructing a new backward Scan
# that runs in reverse. The method:
# 1. Differentiates the inner function symbolically (compute_all_gradients)
# 2. Adds accumulation terms for state inputs at preserved buffer positions
# 3. Builds reversed sequences from the forward outputs
# 4. Converts all recurrent states (sit-sot, mit-sot, mit-mot) into mit-mot
# form in the backward scan (initialized with output gradients, accumulate
# total gradients after evaluation)
# 5. Constructs and runs the backward Scan, then re-orders its outputs

# Determine the number of gradient steps from the output shapes (not from
# inputs[0] directly, because while-loop scans may execute fewer steps than
# the allocated buffer size).
info = self.info
if info.n_nit_sot > 0:
grad_steps = self.outer_nitsot_outs(outs)[0].shape[0]
Expand All @@ -2457,8 +2464,7 @@ def L_op(self, inputs, outs, dC_douts):
if info.as_while:
n_steps = outs[0].shape[0]

# Restrict the number of grad steps according to
# self.truncate_gradient
# Restrict the number of grad steps according to self.truncate_gradient
if self.truncate_gradient != -1:
grad_steps = minimum(grad_steps, self.truncate_gradient)

Expand Down Expand Up @@ -2540,13 +2546,11 @@ def compute_all_gradients(known_grads):
]
gmp = {}

# Required in case there is a pair of variables X and Y, with X
# used to compute Y, for both of which there is an external
# gradient signal. Without this, the total gradient signal on X
# will be the external gradient signalknown_grads[X]. With this,
# it will be the sum of the external gradient signal and the
# gradient obtained by propagating Y's external gradient signal
# to X.
# The .copy() creates fresh variable nodes so that grad() treats them
# as new outputs "equal to" the originals, rather than matching them by
# identity to variables already in the graph. This forces grad() to
# propagate the known_grads values backward through the computation
# instead of short-circuiting at a wrt target.
known_grads = {k.copy(): v for (k, v) in known_grads.items()}

grads = grad(
Expand Down Expand Up @@ -2588,17 +2592,15 @@ def compute_all_gradients(known_grads):
Xt_placeholder = safe_new(Xt)
Xts.append(Xt_placeholder)

# Different processing based on whether Xt is a nitsot output
# or not. NOTE : This cannot be done by using
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs.
# Different processing based on whether Xt is a nitsot output or not.
# NOTE : This cannot be done by using "if Xt not in self.inner_nitsot_outs(self_outputs)"
# because the exact same variable can be used as multiple outputs.
if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end:
# What we do here is loop through dC_douts and collect all
# loop through dC_douts and collect all
# those that are connected to the specific one and do an
# upcast on all of their dtypes to get the dtype for this
# specific output. Deciding if the gradient with this
# specific previous step is defined or not is done somewhere
# else.
# specific previous step is defined or not is done somewhere else.
dtypes = []
for pos, inp in enumerate(states):
if inp in graph_inputs([Xt]):
Expand Down Expand Up @@ -2637,9 +2639,9 @@ def compute_all_gradients(known_grads):
continue

# Just some trouble to avoid a +0
if diff_outputs[i] in known_grads:
try:
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
else:
except KeyError:
known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
dc_dxts_idx += 1

Expand All @@ -2655,6 +2657,9 @@ def compute_all_gradients(known_grads):
)
else:
disconnected_dC_dinps_t[dx] = False
# Replace inner output subexpressions with placeholders wired to the
# saved forward values, so the backward scan reuses them instead of
# recomputing them. See forced_replace docstring for details.
for Xt, Xt_placeholder in zip(
diff_outputs[info.n_mit_mot_outs :], Xts, strict=True
):
Expand All @@ -2663,21 +2668,20 @@ def compute_all_gradients(known_grads):

# construct dX_dtm1
dC_dXtm1s = []
n_internal_recurrent_states = sum(
len(t)
for t in chain(
info.mit_mot_in_slices,
info.mit_sot_in_slices,
info.sit_sot_in_slices,
)
)
for pos, x in enumerate(dC_dinps_t[info.n_seqs :]):
# Get the index of the first inner input corresponding to the
# pos-ieth inner input state
# Get the index of the first inner input corresponding to the pos-th inner input state
idxs = var_mappings["inner_out_from_inner_inp"][info.n_seqs + pos]

# Check if the pos-th input is associated with one of the
# recurrent states
x_is_state = pos < sum(
len(t)
for t in chain(
info.mit_mot_in_slices,
info.mit_sot_in_slices,
info.sit_sot_in_slices,
)
)
# Check if the pos-th input is associated with one of the recurrent states
x_is_state = pos < n_internal_recurrent_states

if x_is_state and len(idxs) > 0:
opos = idxs[0]
Expand All @@ -2687,7 +2691,26 @@ def compute_all_gradients(known_grads):
else:
dC_dXtm1s.append(safe_new(x))

# Skip accumulation for "overlapping" mit-mot taps.
# A mit-mot tap "overlaps" when the same tap index appears in both the input
# and output slices of a single mit-mot state. This means the output *overwrites*
# the input at that buffer position — analogous to set_subtensor(x, y, i).
# The gradient of an overwrite must zero out the direct pass-through from the
# old value; the only gradient path is through the output expression that replaced
# it (already captured by compute_all_gradients via known_grads).
overlapping_taps = set()
dx_offset = 0
for idx in range(info.n_mit_mot):
in_taps = info.mit_mot_in_slices[idx]
out_taps = info.mit_mot_out_slices[idx]
for k, tap in enumerate(in_taps):
if tap in out_taps:
overlapping_taps.add(dx_offset + k)
dx_offset += len(in_taps)

for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
if dx in overlapping_taps:
continue # gradient truncates here
if isinstance(dC_dinps_t[dx + info.n_seqs].type, NullType):
# The accumulated gradient is undefined
pass
Expand Down Expand Up @@ -2761,8 +2784,7 @@ def compute_all_gradients(known_grads):
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]

# Restrict the length of the outer sequences to the number of grad
# steps
# Restrict the length of the outer sequences to the number of grad steps
outer_inp_seqs = [s_[:grad_steps] for s_ in outer_inp_seqs]

inner_inp_seqs = self.inner_seqs(self_inputs)
Expand All @@ -2771,7 +2793,14 @@ def compute_all_gradients(known_grads):
inner_inp_seqs += self.inner_sitsot(self_inputs)
inner_inp_seqs += self.inner_nitsot_outs(dC_dXts)
inner_inp_seqs += Xts
# mitmot
# Build backward scan's mit-mot states.
# Every forward recurrent state (sit-sot, mit-sot, mit-mot) becomes
# a mit-mot in the backward scan. The conversion negates the taps:
# forward output tap t → backward input tap -t (gradient signal)
# forward input tap t → backward output tap -t (gradient to propagate)
# Each backward output tap also needs a backward input tap at the same
# position to carry the accumulated gradient (the recurrence). If one
# already exists from the first rule, they share the buffer slot.
outer_inp_mitmot = []
inner_inp_mitmot = []
inner_out_mitmot = []
Expand Down Expand Up @@ -2810,8 +2839,8 @@ def compute_all_gradients(known_grads):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs])

if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
# We cannot use Null in the inner graph,
# so we use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
pt.zeros(diff_inputs[ins_pos].shape, dtype=config.floatX)
)
Expand Down Expand Up @@ -2919,9 +2948,8 @@ def compute_all_gradients(known_grads):
outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
else:
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# Cannot use dC_dinps_t[ins_pos].dtype, so we use
# floatX instead, as it is a dummy value that will not
# be used anyway.
# Cannot use dC_dinps_t[ins_pos].dtype, so we use floatX instead,
# as it is a dummy value that will not be used anyway.
outer_inp_mitmot.append(
pt.zeros(outs[idx + offset].shape, dtype=config.floatX)
)
Expand All @@ -2933,8 +2961,8 @@ def compute_all_gradients(known_grads):
)

if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead.
# We cannot use Null in the inner graph,
# so we use a zero tensor of the appropriate shape instead.
inner_out_mitmot.append(
pt.zeros(diff_inputs[ins_pos].shape, dtype=config.floatX)
)
Expand Down Expand Up @@ -2974,8 +3002,7 @@ def compute_all_gradients(known_grads):
through_untraced = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
# the right shape
# Replace the inner output with a zero tensor of the right shape
inner_out_sitsot[_p] = pt.zeros(
diff_inputs[ins_pos + _p].shape, dtype=config.floatX
)
Expand All @@ -2993,8 +3020,7 @@ def compute_all_gradients(known_grads):
through_untraced = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
# the right shape
# Replace the inner output with a zero tensor of the right shape
inner_out_nitsot[_p] = pt.zeros(
diff_inputs[_p].shape, dtype=config.floatX
)
Expand Down Expand Up @@ -3089,9 +3115,8 @@ def compute_all_gradients(known_grads):
)
):
if t == "connected":
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# sequences.
# If the forward scan is in as_while mode, we need to pad the gradients,
# so that they match the size of the input sequences.
if info.as_while:
n_zeros = inputs[0] - n_steps
shp = (n_zeros,)
Expand All @@ -3117,9 +3142,8 @@ def compute_all_gradients(known_grads):
end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end], strict=True)):
if t == "connected":
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# sequences.
# If the forward scan is in as_while mode, we need to pad the gradients,
# so that they match the size of the input sequences.
if info.as_while:
n_zeros = inputs[0] - grad_steps
shp = (n_zeros,)
Expand Down
6 changes: 6 additions & 0 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
The scan perform implementation takes the output sizes into consideration,
saving the newest results over the oldest ones whenever the buffer is filled.

This rewrite must only run at compilation time, after grad() has already
built the backward scan. The backward scan needs all intermediate forward
states as sequence inputs (to evaluate f'(x[t])). If this rewrite truncates
buffers before grad() is called, the gradient will be silently wrong.
TODO: Use a subclass that raises explicitly on `L_op`

Parameters
----------
backend_supports_output_pre_allocation: bool
Expand Down
Loading