
Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor#2015

Draft
ricardoV94 wants to merge 7 commits into pymc-devs:v3 from ricardoV94:gather_scatter_fusion

Conversation


@ricardoV94 ricardoV94 commented Mar 29, 2026

Summary

Introduce IndexedElemwise, an OpFromGraph that wraps AdvancedSubtensor + Elemwise + AdvancedIncSubtensor subgraphs so the Numba backend can generate a single loop with indirect indexing. This avoids materializing the AdvancedSubtensor input arrays and writes directly into the output buffer, doing the job of AdvancedIncSubtensor in the same loop instead of looping again over the intermediate elemwise output.

Commit 1 fuses indexed reads (AdvancedSubtensor1 on inputs).
Commit 2 fuses indexed updates (AdvancedIncSubtensor1 on outputs).
Commit 3 extends fusion to AdvancedSubtensor inputs, with arbitrary (1d) indexed (consecutive) axes.

Motivation

In hierarchical models with mu = beta[idx] * x + ..., the logp+gradient graph combines indexed reads and indexed updates in the same Elemwise (the forward reads county-level parameters via an index, and the gradient accumulates back into county-level buffers via the same index).

A simpler example:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

numba_mode = get_mode("NUMBA")
numba_mode_before = numba_mode.excluding("fuse_indexed_elemwise")

x = pt.vector("x")
idx = pt.vector("idx", dtype=int)
value = pt.vector("value")

y = pt.zeros(100)
out = ((x[idx] - value) ** 2).sum()
grad_wrt_x = pt.grad(out, x)
fn_before = pytensor.function([x, value, idx], [out, grad_wrt_x], mode=numba_mode_before, trust_input=True)
fn_before.dprint(print_op_info=True, print_destroy_map=True)
# Sum{axes=None} [id A] 5
#  └─ Composite{...}.0 [id B] d={0: [0]} 1
#     ├─ AdvancedSubtensor1 [id C] 0
#     │  ├─ x [id D]
#     │  └─ idx [id E]
#     └─ value [id F]
# AdvancedIncSubtensor1{inplace,inc} [id G] d={0: [0]} 4
#  ├─ Alloc [id H] 3
#  │  ├─ [0.] [id I]
#  │  └─ Shape_i{0} [id J] 2
#  │     └─ x [id D]
#  ├─ Composite{...}.1 [id B] d={0: [0]} 1
#  │  └─ ···
#  └─ idx [id E]

# Inner graphs:

# Composite{...} [id B] d={0: [0]}
#  ← sqr [id K] 'o0'
#     └─ sub [id L] 't5'
#        ├─ i0 [id M]
#        └─ i1 [id N]
#  ← mul [id O] 'o1'
#     ├─ 2.0 [id P]
#     └─ sub [id L] 't5'
#        └─ ···

fn = pytensor.function([x, value, idx], [out, grad_wrt_x], mode=numba_mode, trust_input=True)
fn.dprint(print_op_info=True, print_destroy_map=True)

# Sum{axes=None} [id A] 3
#  └─ IndexedElemwise{Composite{...}}.0 [id B] d={1: [3]} 2
#     ├─ x [id C] (indexed read (idx_0))
#     ├─ value [id D]
#     ├─ idx [id E] (idx_0)
#     └─ Alloc [id F] 1 (buf_0)
#        ├─ [0.] [id G]
#        └─ Shape_i{0} [id H] 0
#           └─ x [id C]
# IndexedElemwise{Composite{...}}.1 [id B] d={1: [3]} 2 (indexed inc (buf_0, idx_0))
#  └─ ···

# Inner graphs:

# IndexedElemwise{Composite{...}} [id B] d={1: [3]}
#  ← Composite{...}.0 [id I]
#     ├─ AdvancedSubtensor1 [id J]
#     │  ├─ *0-<Vector(float64, shape=(?,))> [id K]
#     │  └─ *2-<Vector(int64, shape=(?,))> [id L]
#     └─ *1-<Vector(float64, shape=(?,))> [id M]
#  ← AdvancedIncSubtensor1{inplace,inc} [id N] d={0: [0]}
#     ├─ *3-<Vector(float64, shape=(?,))> [id O]
#     ├─ Composite{...}.1 [id I]
#     │  └─ ···
#     └─ *2-<Vector(int64, shape=(?,))> [id L]

# Composite{...} [id I]
#  ← sqr [id P] 'o0'
#     └─ sub [id Q] 't0'
#        ├─ i0 [id R]
#        └─ i1 [id S]
#  ← mul [id T] 'o1'
#     ├─ 2.0 [id U]
#     └─ sub [id Q] 't0'
#        └─ ···

x_test = np.arange(15, dtype="float64")
idx_test = np.random.randint(15, size=(10_000,))
value_test = np.random.normal(size=idx_test.shape)

logp_before, dlogp_before = fn_before(x_test, value_test, idx_test)
logp, dlogp = fn(x_test, value_test, idx_test)
np.testing.assert_allclose(logp_before, logp)
np.testing.assert_allclose(dlogp_before, dlogp)

%timeit fn_before(x_test, value_test, idx_test)  # 29.4 μs ± 2.57 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit fn(x_test, value_test, idx_test)  # 13.8 μs ± 136 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

The next step would be to also fuse the sum directly into the elemwise, so we end up with a single loop over the data. This matters because a sum can easily break the fusion: we don't fuse if the elemwise output is needed elsewhere (as it is by the sum).
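A minimal NumPy sketch of what that hypothetical sum fusion would compute (illustrative only, not the generated Numba code): the reduction is accumulated inside the same loop, so the elemwise output is never materialized even though it feeds a Sum.

```python
import numpy as np

rng = np.random.default_rng(5)
x = rng.normal(size=15)
idx = rng.integers(15, size=100)
value = rng.normal(size=100)

# Reference: materialize the elemwise output, then reduce
expected = ((x[idx] - value) ** 2).sum()

# Fused form: one loop, accumulating the sum directly
acc = 0.0
for k in range(100):
    acc += (x[idx[k]] - value[k]) ** 2

np.testing.assert_allclose(acc, expected)
```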

@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 2 times, most recently from 6d875d8 to 0ad6e2e Compare March 29, 2026 18:14
@ricardoV94 ricardoV94 changed the title Numba: fuse AdvancedSubtensor+Elemwise Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor Mar 29, 2026
updates:
- [github.com/sphinx-contrib/sphinx-lint: v1.0.0 → v1.0.2](sphinx-contrib/sphinx-lint@v1.0.0...v1.0.2)
- [github.com/astral-sh/ruff-pre-commit: v0.14.0 → v0.15.8](astral-sh/ruff-pre-commit@v0.14.0...v0.15.8)
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 2 times, most recently from 9e32400 to e3c59c6 Compare March 31, 2026 18:58
Fuse single-client AdvancedSubtensor1 nodes into Elemwise loops,
replacing indirect array reads with a single iteration loop that
uses index arrays for input access.

Before (2 nodes):
  temp = x[idx]                    # AdvancedSubtensor1, shape (919,)
  result = temp + y                # Elemwise

After (1 fused loop, x is read directly via idx):
  for k in range(919):
      result[k] = x[idx[k]] + y[k]

- Introduce IndexedElemwise Op (in rewriting/indexed_elemwise.py)
- Add FuseIndexedElemwise rewrite with SequenceDB
- Merge _vectorized intrinsics into one with NO_SIZE/NO_INDEXED sentinels
- Fix Numba missing getitem(0d_array, Ellipsis)
- Index arrays participate in iter_shape with correct static bc
- zext for unsigned index types
- Add op_debug_information for dprint(print_op_info=True)
- Add correctness tests and benchmarks
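A minimal NumPy sketch of the read-fusion semantics described in this commit (illustrative names and shapes; not the actual generated code): the intermediate `temp = x[idx]` buffer disappears and `x` is read indirectly inside the elemwise loop.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=85)
idx = rng.integers(85, size=919)
y = rng.normal(size=919)

# Before: two nodes, temp is materialized
temp = x[idx]
expected = temp + y

# After: one loop, x read directly via idx, no temp buffer
result = np.empty(919)
for k in range(919):
    result[k] = x[idx[k]] + y[k]

np.testing.assert_allclose(result, expected)
```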
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 2 times, most recently from 4156818 to a5d7d03 Compare March 31, 2026 22:57
Extend the IndexedElemwise fusion to also absorb
AdvancedIncSubtensor1 (indexed set/inc) on the output side.

Before (3 nodes):
  temp = Elemwise(x[idx], y)               # shape (919,)
  result = IncSubtensor(target, temp, idx)  # target shape (85,)

After (1 fused loop, target is an input):
  for k in range(919):
      target[idx[k]] += scalar_fn(x[idx[k]], y[k])

- FuseIndexedElemwise now detects AdvancedIncSubtensor1 consumers
- Reject fusion when val broadcasts against target's non-indexed axes
- store_core_outputs supports inc mode via o[...] += val
- Inner fgraph always uses inplace IncSubtensor
- op_debug_information shows buf_N / idx_N linkage
- Add indexed-update tests, broadcast guard test, and benchmarks
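A minimal NumPy sketch of the write-side fusion (illustrative, using multiplication as a stand-in scalar function): the elemwise output is never materialized, and the scatter-add into `target` happens in the same loop as the indexed read.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=85)
idx = rng.integers(85, size=919)
y = rng.normal(size=919)

# Before: materialize the elemwise output, then scatter-add
temp = x[idx] * y
expected = np.zeros(85)
np.add.at(expected, idx, temp)  # unbuffered inc, like AdvancedIncSubtensor1

# After: single loop, target is an input buffer written through idx
target = np.zeros(85)
for k in range(919):
    target[idx[k]] += x[idx[k]] * y[k]

np.testing.assert_allclose(target, expected)
```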
Support AdvancedSubtensor on any axis (not just axis 0) and multi-index
patterns like x[idx_row, idx_col] where multiple 1D index arrays address
consecutive source axes.

Arbitrary axis:
  x[:, idx] + y → fused loop with indirect indexing on axis 1

Multi-index:
  x[idx0, idx1] + y → out[i, j] = x[idx0[i], idx1[i], j] + y[i, j]

- Add undo_take_dimshuffle_for_fusion pre-fusion rewrite
- Generalize indexed_inputs encoding: ((positions, axis, idx_bc), ...)
- input_read_spec uses tuple of (idx_k, axis) pairs per input
- source_input_types for array struct access, input_types (effective)
  for core_ndim / _compute_vectorized_types
- n_index_loop_dims = max(idx.ndim for group) for future ND support
- Index arrays participate in iter_shape with correct per-index static bc
- Reject boolean indices in AdvancedSubtensor fusion paths
- Reject fusion when val broadcasts against target's non-indexed axes
- Add correctness, broadcast, and shape validation tests
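A minimal NumPy sketch of the multi-index pattern above (illustrative shapes): `idx0` and `idx1` address two consecutive axes of `x`, and the remaining source axis becomes an ordinary loop dimension.

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(size=(5, 6, 7))
idx0 = rng.integers(5, size=10)
idx1 = rng.integers(6, size=10)
y = rng.normal(size=(10, 7))

# Reference: NumPy advanced indexing on consecutive axes
expected = x[idx0, idx1] + y  # shape (10, 7)

# Fused form: out[i, j] = x[idx0[i], idx1[i], j] + y[i, j]
out = np.empty((10, 7))
for i in range(10):
    for j in range(7):
        out[i, j] = x[idx0[i], idx1[i], j] + y[i, j]

np.testing.assert_allclose(out, expected)
```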

Extend FusionOptimizer to merge independent subgraphs that share
inputs but have no producer-consumer edge (siblings like f(x) and g(x)).
The eager expansion only walks producer-consumer edges, missing these.

Also extract InplaceGraphOptimizer.try_inplace_on_node helper and
_insert_sorted_subgraph to deduplicate insertion-point logic.

Support multidimensional integer indices (e.g. exp(x[mat_idx]) where
mat_idx is 2D) in IndexedElemwise fusion.  An ND index contributes
multiple loop dimensions — one per index dimension.

Rewrite:
- Add undo_take_reshape_for_fusion: undoes the Reshape+flatten pattern
  that transform_take applies for ND indices, recovering the original
  AdvancedSubtensor(source, mat_idx) form for fusion.
  Handles both axis=0 and axis>0 (with DimShuffle wrapping).

Fusion:
- _get_indexed_read_info now accepts idx.ndim >= 1 (was == 1)
- Numba dispatch likewise accepts ND indices in AdvancedSubtensor

Codegen:
- idx_load_axes: tuple of tuples (was flat ints), each index array
  loads from idx_ndim loop counters
- Indirect index loading uses multiple loop counters for ND arrays
- iter_shape construction maps ND index dims to correct loop dims
- result_dim advancement accounts for ND index expansion
- Effective input types handle source ndim != result ndim
- Bump IndexedElemwise cache version
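A minimal NumPy sketch of the ND-index case (illustrative): a 2D integer index contributes two loop dimensions, one per index dimension.

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=15)
mat_idx = rng.integers(15, size=(4, 6))

# Reference: advanced indexing with a 2D index, then elemwise
expected = np.exp(x[mat_idx])  # shape (4, 6)

# Fused form: each index dimension becomes a loop dimension
out = np.empty((4, 6))
for i in range(4):
    for j in range(6):
        out[i, j] = np.exp(x[mat_idx[i, j]])

np.testing.assert_allclose(out, expected)
```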

Generalize the write (update) side of IndexedElemwise fusion to support
AdvancedIncSubtensor with multiple index arrays on consecutive axes.

Rewrite:
- find_indexed_update_consumers now detects both AdvancedIncSubtensor1
  and AdvancedIncSubtensor, extracting idx_axis_pairs per update
- _get_indexed_update_info helper mirrors _get_indexed_read_info
- Broadcast guard generalized for non-axis-0 indexed axes
- Scatter construction supports AdvancedIncSubtensor (inplace)

Dispatch:
- Update analysis detects AdvancedIncSubtensor in inner fgraph
- indexed_outputs built from per-index axis pairs
- output_bc_patterns use Elemwise output bc (loop dims) so all
  bc patterns have matching batch_ndim

Codegen:
- n_update_targets counts unique output indices (not per-index entries)
- update_out_to_target deduplicates for shared targets
- Write-side idx_load_axes uses group min_axis from write groups
- output_update_spec recomputes core_ndim from target + spec
- Write-only index arrays skip iter_shape validation
- Bump cache version
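A minimal NumPy sketch of the generalized write side (illustrative shapes): an AdvancedIncSubtensor-style update with two index arrays on consecutive axes of the target, folded into a single loop.

```python
import numpy as np

rng = np.random.default_rng(4)
target_before = np.zeros((5, 6))
idx0 = rng.integers(5, size=20)
idx1 = rng.integers(6, size=20)
val = rng.normal(size=20)

# Reference: NumPy's unbuffered scatter-add with a multi-index
expected = target_before.copy()
np.add.at(expected, (idx0, idx1), val)

# Fused form: one loop writing through both index arrays
target = target_before.copy()
for k in range(20):
    target[idx0[k], idx1[k]] += val[k]

np.testing.assert_allclose(target, expected)
```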
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch from a5d7d03 to 473ada3 Compare March 31, 2026 23:16