Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,22 @@ def checkSV(sv_ori, sv_rpl):
# The first len(maker.outputs) variables are original variables.
# The rest are the updates.
out_vars = maker.fgraph.outputs[: len(maker.outputs)]

# Warn if any shared variable with a deleted update is being
# destroyed (mutated inplace) by a node in the graph. In that case
# the shared variable storage will still be mutated even though the
# update is not applied.
if hasattr(maker.fgraph, "destroyers"):
for in_idx in maker.fgraph.update_mapping.values():
input_var = maker.fgraph.inputs[in_idx]
if maker.fgraph.destroyers(input_var):
warnings.warn(
f"Shared variable '{input_var.name}' will still be mutated "
"even though its update is being deleted, because an "
"operation in the graph modifies it inplace.",
UserWarning,
stacklevel=2,
)
else:
out_vars = maker.fgraph.outputs

Expand Down
7 changes: 6 additions & 1 deletion pytensor/link/numba/dispatch/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
register_funcify_and_cache_key,
)
from pytensor.link.numba.dispatch.vectorize_codegen import (
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
NO_SIZE,
_jit_options,
_vectorized,
encode_literals,
Expand Down Expand Up @@ -90,7 +93,9 @@ def impl(*inputs_and_core_shapes):
(), # constant_inputs
inputs,
tuple_core_shapes,
None, # size
NO_SIZE,
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
)

return impl
Expand Down
130 changes: 125 additions & 5 deletions pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
)
from pytensor.link.numba.dispatch.string_codegen import create_tuple_string
from pytensor.link.numba.dispatch.vectorize_codegen import (
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
NO_SIZE,
_jit_options,
_vectorized,
encode_literals,
store_core_outputs,
Expand All @@ -35,13 +39,13 @@
Mul,
Sub,
TrueDiv,
get_scalar_type,
maximum,
)
from pytensor.scalar.basic import add as add_as
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum
from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


Expand Down Expand Up @@ -312,8 +316,7 @@ def axis_apply_fn(x):

@register_funcify_and_cache_key(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
scalar_node = op.scalar_op.make_node(*scalar_inputs)
scalar_node = op.make_scalar_node(*node.inputs)
scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
op.scalar_op,
node=scalar_node,
Expand Down Expand Up @@ -368,8 +371,10 @@ def impl(*inputs):
True, # allow_core_scalar
(), # constant_inputs
inputs,
core_output_shapes, # core_shapes
None, # size
core_output_shapes,
NO_SIZE,
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
)

return impl
Expand All @@ -390,6 +395,121 @@ def impl(*inputs):
return elemwise, elemwise_key


@register_funcify_and_cache_key(IndexedElemwise)
def numba_funcify_IndexedElemwise(op, node, **kwargs):
    """Generate fused Elemwise Numba code with indexed reads and updates.

    Reads indexed_inputs/indexed_outputs specs stored on the Op by the
    rewriting pass, and generates a single vectorized loop with indirect
    indexing.

    Outer inputs are ordered as::

        [elemwise_inputs..., idx_0, idx_1, ..., update_target_0, ...]

    Returns
    -------
    (callable, str | None)
        The dispatch function (JIT-only; raises in pure Python) and a
        sha256 cache key, or ``None`` when the scalar op has no stable key.
    """
    # The op's inner fgraph must contain exactly one Elemwise node; the
    # single-element unpacking raises if that invariant is violated.
    [elemwise_node] = [n for n in op.fgraph.apply_nodes if isinstance(n.op, Elemwise)]

    indexed_inputs = op.indexed_inputs
    indexed_outputs = op.indexed_outputs

    # Derive update_out_set and inc_outputs from stored specs.
    # NOTE(review): each non-None entry is treated as (out_indices, mode)
    # with mode == "inc" meaning an accumulating update -- confirm against
    # the spec written by the IndexedElemwise rewriting pass.
    update_out_set = frozenset(
        out_idx
        for entry in indexed_outputs
        if entry is not None
        for out_idx in entry[0]
    )
    inc_outputs = frozenset(
        out_idx
        for entry in indexed_outputs
        if entry is not None
        for out_idx in entry[0]
        if entry[1] == "inc"
    )

    # --- Scalar function and core op --------------------------------------
    # Build the scalar node for the inner Elemwise and compile its scalar op;
    # scalar_cache_key may be None, which disables caching below.
    scalar_node = elemwise_node.op.make_scalar_node(*elemwise_node.inputs)
    scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
        elemwise_node.op.scalar_op, node=scalar_node, **kwargs
    )

    nin_elemwise = len(elemwise_node.inputs)
    nout = len(elemwise_node.outputs)
    # inc_outputs tells the core-output writer which outputs accumulate
    # (+=) into their buffer instead of overwriting it.
    core_op_fn = store_core_outputs(
        scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs
    )

    # --- Broadcast and type encodings -------------------------------------
    # Use each input's result broadcastable (not the source's).
    # For indexed reads, the result has the index dim(s) substituted.
    # For multi-index, the result has fewer dims than the source.
    input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs)
    # Use Elemwise output bc (loop dims) for ALL outputs — including updates.
    # The fix-up block in _vectorized corrects the actual output/return types
    # for update outputs to match the target buffer.
    output_bc_patterns = tuple(out.type.broadcastable for out in elemwise_node.outputs)
    # Dtypes come from the *outer* node's outputs, not the inner Elemwise
    # outputs — presumably so update outputs carry the target buffer's dtype;
    # verify against the IndexedElemwise Op's output construction.
    output_dtypes = tuple(out.type.dtype for out in node.outputs)
    # Filter out elemwise inplace for indexed-update outputs (they use update targets)
    inplace_pattern = tuple(
        (out_idx, inp_idx)
        for out_idx, inp_idx in elemwise_node.op.inplace_pattern.items()
        if out_idx not in update_out_set
    )
    # Every core output is a scalar, so all core shapes are empty tuples.
    core_output_shapes = tuple(() for _ in range(nout))

    # Encode each spec as a literal so Numba can specialize the generated
    # loop on these values at compile time.
    input_bc_patterns_enc = encode_literals(input_bc_patterns)
    output_bc_patterns_enc = encode_literals(output_bc_patterns)
    output_dtypes_enc = encode_literals(output_dtypes)
    inplace_pattern_enc = encode_literals(inplace_pattern)
    indexed_inputs_enc = encode_literals(indexed_inputs)
    indexed_outputs_enc = encode_literals(indexed_outputs)

    def indexed_elemwise_fn(*outer_inputs):
        # Pure-Python placeholder: this Op is only usable through the Numba
        # overload registered below.
        raise NotImplementedError(
            "IndexedElemwise cannot be evaluated in Python (non-JIT) mode."
        )

    @overload(indexed_elemwise_fn, jit_options=_jit_options)
    def ov_indexed_elemwise_fn(*outer_inputs):
        def impl(*outer_inputs):
            return _vectorized(
                core_op_fn,
                input_bc_patterns_enc,
                output_bc_patterns_enc,
                output_dtypes_enc,
                inplace_pattern_enc,
                True,  # allow_core_scalar
                (),  # constant_inputs
                outer_inputs,
                core_output_shapes,
                NO_SIZE,
                indexed_inputs_enc,
                indexed_outputs_enc,
            )

        return impl

    # --- Cache key ---------------------------------------------------------
    # cache_version must be bumped whenever the generated code changes in a
    # way that invalidates previously cached compilations.
    cache_version = (0, 5)
    if scalar_cache_key is None:
        # No stable key for the scalar op: disable caching for this function.
        key = None
    else:
        # Fold every compile-time-relevant spec into a deterministic string,
        # then hash it so the key has a fixed, filesystem-safe form.
        key = str(
            (
                type(op),
                "IndexedElemwise",
                cache_version,
                inplace_pattern,
                input_bc_patterns,
                indexed_inputs,
                indexed_outputs,
                scalar_cache_key,
            )
        )
        key = sha256(key.encode()).hexdigest()

    return indexed_elemwise_fn, key


@register_funcify_and_cache_key(CAReduce)
def numba_funcify_CAReduce(op, node, **kwargs):
axes = op.axis
Expand Down
7 changes: 6 additions & 1 deletion pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.vectorize_codegen import (
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
NO_SIZE,
_jit_options,
_vectorized,
encode_literals,
Expand Down Expand Up @@ -474,9 +477,11 @@ def impl(core_shape, rng, size, *dist_params):
(rng,),
dist_params,
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
None
NO_SIZE
if size_len is None
else numba_ndarray.to_fixed_tuple(size, size_len),
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
)
return rng, draws

Expand Down
Loading
Loading